Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 83 additions & 20 deletions lib/python/flame/backend/shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,29 +188,92 @@ async def _create_join_inner_task():

def leave(self, channel) -> None:
    """Leave a given channel.

    Gracefully clean up the shared memory buffers associated with the
    channel and the per-role SharedMemoryDict bookkeeping segments.
    FileNotFoundError is tolerated everywhere, since another participant
    (or a previous call) may already have unlinked a segment.

    TODO: notify the sockmap manager to remove the entry from eBPF map
    """
    logger.info("Clean up shared memory buffers.")

    for end in channel.all_ends():
        try:
            shm_buf = shared_memory.SharedMemory(name=end)
            # Close our handle on the shared memory segment.
            shm_buf.close()
            # If this end belongs to us, we created it, so we can unlink it.
            if end == self._id:
                try:
                    shm_buf.unlink()
                except FileNotFoundError:
                    logger.debug(f"Shared memory segment {end} already unlinked.")
        except FileNotFoundError:
            logger.debug(f"Shared memory segment {end} not found during cleanup (already removed?).")

    # Clean up the local ends dictionary.
    # NOTE: constructing SharedMemoryDict may recreate the shm dict.
    shm_ends_name = channel.name() + "-" + channel.my_role()
    try:
        shm_ends = SharedMemoryDict(name=shm_ends_name, size=SHM_DICT_SIZE)
        # Remove our entry, if it is still present.
        if self._id in shm_ends:
            del shm_ends[self._id]

        if len(shm_ends) == 0:
            # We were the last entry; release and unlink the dict segment.
            shm_ends.shm.close()
            try:
                shm_ends.shm.unlink()
            except FileNotFoundError:
                logger.debug(f"Shared memory dict {shm_ends_name} already unlinked.")
            del shm_ends
        else:
            # Just close shm but do not unlink if entries remain.
            shm_ends.shm.close()
    except FileNotFoundError:
        logger.debug(f"No shared memory dict found for {shm_ends_name}; may have been unlinked already.")

    # Clean up the peer ends dictionary.
    # NOTE: constructing SharedMemoryDict may recreate the shm dict.
    other_ends_name = channel.name() + "-" + channel.other_role()
    try:
        other_ends = SharedMemoryDict(name=other_ends_name, size=SHM_DICT_SIZE)
        other_ends.shm.close()
        # We don't unlink here: the peer-role dictionary may still be
        # needed by other participants. If ownership tracking is added
        # later, the final owner can unlink it (guarding FileNotFoundError).
    except FileNotFoundError:
        logger.debug(f"No shared memory dict found for {other_ends_name}; may have been unlinked already.")

    logger.debug("channel leave completed gracefully")

def create_tx_task(
self, channel_name: str, end_id: str, comm_type=CommType.UNICAST
Expand Down
44 changes: 12 additions & 32 deletions lib/python/flame/channel_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Channel manager."""

import asyncio
import atexit
Expand All @@ -31,17 +30,7 @@


def custom_excepthook(exc_type, exc_value, exc_traceback):
    """Implement a custom exception hook.

    Logs any uncaught exception, including its traceback, at CRITICAL
    level via this module's logger instead of relying on the default
    ``sys.excepthook`` behavior.

    Args:
        exc_type: Exception class of the uncaught exception.
        exc_value: Exception instance.
        exc_traceback: Traceback object associated with the exception.
    """
    logger.critical(
        "Uncaught exception:", exc_info=(exc_type, exc_value, exc_traceback)
    )
Expand Down Expand Up @@ -69,7 +58,7 @@ class ChannelManager(object):
def __new__(cls):
    """Create a singleton instance.

    Returns the one shared ChannelManager instance, creating it on the
    first call and reusing it thereafter.
    """
    if cls._instance is None:
        # First instantiation: create and cache the singleton.
        logger.info("Creating a ChannelManager instance")
        cls._instance = super(ChannelManager, cls).__new__(cls)
    return cls._instance

Expand All @@ -87,37 +76,31 @@ def __call__(self, config: Config):
atexit.register(self.cleanup)

def _setup_backends(self):
    """Instantiate and configure one backend per distinct backend sort.

    Channels sharing the same backend sort share a single backend
    instance. If some channel has no backend of its own, a default
    backend (from the top-level config) is set up as a fallback.
    """
    distinct_backends = {}

    for ch_name, channel in self._config.channels.items():
        # Rename backend in channel config as sort to avoid confusion.
        sort = channel.backend
        if not sort:
            # Channel doesn't have its own backend; it will use the default.
            logger.warning(f"No backend specified for channel {ch_name}.")
            continue

        if sort not in distinct_backends:
            # Create a new backend instance if it doesn't exist.
            backend = backend_provider.get(sort)
            broker_host = channel.broker_host or self._config.brokers.sort_to_host.get(sort)

            backend.configure(broker_host, self._job_id, self._task_id)

            distinct_backends[sort] = backend

        # Assign the (possibly shared) backend instance to the channel.
        self._backends[ch_name] = distinct_backends[sort]

    if len(self._backends) == len(self._config.channels):
        # Every channel has its own backend; no default backend needed.
        return

    # Set up a default backend for channels without one of their own.
    sort = self._config.backend
    if sort not in distinct_backends:
        self._backend = backend_provider.get(sort)
        broker_host = self._config.brokers.sort_to_host.get(sort)
        self._backend.configure(broker_host, self._job_id, self._task_id)
    else:
        self._backend = distinct_backends[sort]
Expand Down Expand Up @@ -152,7 +135,7 @@ def join(self, name: str) -> bool:
if name in self._backends:
backend = self._backends[name]
else:
logger.info(f"no backend found for channel {name}; use default")
logger.info(f"No backend found for channel {name}; using default backend.")
backend = self._backend

self._channels[name] = Channel(
def leave(self, name):
    """Leave a channel and remove it from the channel registry.

    No-op if the channel was never joined.

    Args:
        name: Name of the channel to leave.
    """
    if not self.is_joined(name):
        return

    self._channels[name].leave()
    del self._channels[name]

Expand All @@ -182,7 +163,6 @@ def get_by_tag(self, tag: str) -> Optional[Channel]:
def get(self, name: str) -> Optional[Channel]:
    """Return a channel object in a given channel name.

    Args:
        name: Name of the channel to look up.

    Returns:
        The joined Channel, or None if the channel was not joined.
    """
    if not self.is_joined(name):
        return None

    return self._channels[name]
Expand All @@ -193,9 +173,9 @@ def is_joined(self, name):

def cleanup(self):
"""Clean up pending asyncio tasks."""
logger.debug("calling cleanup")
logger.debug("Calling cleanup")
for _, ch in self._channels.items():
logger.debug(f"calling leave for channel {ch.name()}")
logger.debug(f"Calling leave for channel {ch.name()}")
ch.leave()

async def _inner(backend):
Expand All @@ -204,12 +184,12 @@ async def _inner(backend):
try:
await task
except asyncio.CancelledError:
logger.debug(f"successfully cancelled {task.get_name()}")
logger.debug(f"Successfully cancelled {task.get_name()}")

logger.debug("done with cleaning up asyncio tasks")
logger.debug("Done cleaning up asyncio tasks")

if self._backend:
_ = run_async(_inner(self._backend), self._backend.loop())

for k, v in self._backends.items():
_ = run_async(_inner(v), v.loop())
_ = run_async(_inner(v), v.loop())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please add a new line.