From 7af4cb647598cbeb9ab3eaf03343d47a2643a1a2 Mon Sep 17 00:00:00 2001 From: yxs Date: Wed, 1 Apr 2026 23:24:57 -0600 Subject: [PATCH] [fix] handle FSDP DTensor in broadcast_from_megatron_pp Megatron FSDP (ZeRO-3) stores parameters as DTensors. When export_weights broadcasts params across PP ranks, torch.distributed.broadcast() triggers DTensor dispatch and fails because the PP group is not in the DTensor's DeviceMesh. Fix: call DTensor.full_tensor() to materialize the full parameter before broadcasting. --- mbridge/core/util.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mbridge/core/util.py b/mbridge/core/util.py index 57d9671..9a47e64 100644 --- a/mbridge/core/util.py +++ b/mbridge/core/util.py @@ -268,6 +268,14 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): def broadcast_from_megatron_pp(tensor: torch.Tensor): # tensor is not None only in one of the pp ranks if tensor is not None: + # FSDP DTensor: gather full tensor before PP broadcast. + try: + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + except ImportError: + pass shape = tensor.shape dtype = tensor.dtype tensor_parallel = getattr(tensor, "tensor_model_parallel", None)