diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index ee933929f9..bc190536bb 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -546,6 +546,9 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): overlap_param_gather: bool = False overlap_param_gather_with_optimizer_step: bool = False align_grad_reduce: bool = True + # Eagerly create NCCL communicators before the training loop to avoid the lazy + # first-use allocation hitting the iteration-1 memory peak (Failed to CUDA calloc async). + nccl_comm_warmup: bool = False virtual_pipeline_model_parallel_size: Optional[int] = None microbatch_group_size_per_vp_stage: Optional[int] = None pipeline_model_parallel_layout: Optional[str] = None diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index b3149f81cf..cb49c3e8c6 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -615,6 +615,37 @@ def setup_model_training(self): self._saved_param_sync_func = config.param_sync_func config.param_sync_func = None + if args.nccl_comm_warmup: + # Eagerly create NCCL communicators while GPU memory is still free. Lazily-initialized + # comms (e.g. the dp/cp loss all-reduce and grad-sync coalescing) otherwise first fire + # at the iteration-1 memory peak, where NCCL's internal cudaMalloc can fail with + # "Failed to CUDA calloc async N bytes". A 1-element dummy all-reduce per group is + # numerically inert and forces the communicator to be created up front. + dummy = torch.zeros(1, device=get_current_device()) + warmed = 0 + for getter, kwargs in ( + (mpu.get_data_parallel_group, { + 'with_context_parallel': True + }), + (mpu.get_data_parallel_group, {}), + (mpu.get_context_parallel_group, {}), + (mpu.get_tensor_model_parallel_group, {}), + (mpu.get_pipeline_model_parallel_group, {}), + (mpu.get_model_parallel_group, {}), + (mpu.get_embedding_group, {}), + (mpu.get_position_embedding_group, {}), + ): + try: + group = getter(**kwargs) + except (AssertionError, ValueError, TypeError): + continue + for g in (group if isinstance(group, list) else [group]): + if g is not None: + torch.distributed.all_reduce(dummy, group=g) + warmed += 1 + torch.cuda.synchronize() + logger.info(f'NCCL communicator warm-up done ({warmed} groups).') + self.call_event('on_train_begin') self._train_metrics = {}