From e9c94c3b3c99308c5dfc3db53e9979c4189117a9 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Mon, 6 Apr 2026 16:25:11 -0700 Subject: [PATCH] warning for process group name not found conditionally stack-info: PR: https://github.com/pytorch/helion/pull/1973, branch: shunting314/stack/35 --- helion/_dist_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/helion/_dist_utils.py b/helion/_dist_utils.py index 7bd156fcc..ad7f0bf4c 100644 --- a/helion/_dist_utils.py +++ b/helion/_dist_utils.py @@ -210,9 +210,16 @@ def _find_process_group_name(fn: Callable, args: tuple[object, ...]) -> str | No assert isinstance(arg, str), f"{type(arg)}" return arg - warning(ProcessGroupNameNotFound) assert dist.group.WORLD is not None - return dist.group.WORLD.group_name + group_name = dist.group.WORLD.group_name + has_symm_mem_tensor = any( + is_symm_mem_tensor(arg, group_name) + for arg in args + if isinstance(arg, torch.Tensor) + ) + if has_symm_mem_tensor: + warning(ProcessGroupNameNotFound) + return group_name def _clone_symm_mem_tensor(