diff --git a/helion/_dist_utils.py b/helion/_dist_utils.py index 7bd156fcc..ad7f0bf4c 100644 --- a/helion/_dist_utils.py +++ b/helion/_dist_utils.py @@ -210,9 +210,16 @@ def _find_process_group_name(fn: Callable, args: tuple[object, ...]) -> str | No assert isinstance(arg, str), f"{type(arg)}" return arg - warning(ProcessGroupNameNotFound) assert dist.group.WORLD is not None - return dist.group.WORLD.group_name + group_name = dist.group.WORLD.group_name + has_symm_mem_tensor = any( + is_symm_mem_tensor(arg, group_name) + for arg in args + if isinstance(arg, torch.Tensor) + ) + if has_symm_mem_tensor: + warning(ProcessGroupNameNotFound) + return group_name def _clone_symm_mem_tensor(