diff --git a/mbridge/core/util.py b/mbridge/core/util.py index 57d9671..9a47e64 100644 --- a/mbridge/core/util.py +++ b/mbridge/core/util.py @@ -268,6 +268,14 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): def broadcast_from_megatron_pp(tensor: torch.Tensor): # tensor is not None only in one of the pp ranks if tensor is not None: + # FSDP DTensor: gather full tensor before PP broadcast. + try: + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + except ImportError: + pass shape = tensor.shape dtype = tensor.dtype tensor_parallel = getattr(tensor, "tensor_model_parallel", None)