Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions swift/megatron/arguments/megatron_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
overlap_param_gather: bool = False
overlap_param_gather_with_optimizer_step: bool = False
align_grad_reduce: bool = True
# Eagerly create NCCL communicators before the training loop so that their lazy
# first-use allocation does not land on the iteration-1 memory peak, where it can fail
# with the NCCL error "Failed to CUDA calloc async".
nccl_comm_warmup: bool = False
virtual_pipeline_model_parallel_size: Optional[int] = None
microbatch_group_size_per_vp_stage: Optional[int] = None
pipeline_model_parallel_layout: Optional[str] = None
Expand Down
31 changes: 31 additions & 0 deletions swift/megatron/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,37 @@ def setup_model_training(self):
self._saved_param_sync_func = config.param_sync_func
config.param_sync_func = None

if args.nccl_comm_warmup:
# Eagerly create NCCL communicators while GPU memory is still free. Lazily-initialized
# comms (e.g. the dp/cp loss all-reduce and grad-sync coalescing) otherwise first fire
# at the iteration-1 memory peak, where NCCL's internal cudaMalloc can fail with
# "Failed to CUDA calloc async N bytes". A 1-element dummy all-reduce per group is
# numerically inert and forces the communicator to be created up front.
dummy = torch.zeros(1, device=get_current_device())
warmed = 0
for getter, kwargs in (
(mpu.get_data_parallel_group, {
'with_context_parallel': True
}),
(mpu.get_data_parallel_group, {}),
(mpu.get_context_parallel_group, {}),
(mpu.get_tensor_model_parallel_group, {}),
(mpu.get_pipeline_model_parallel_group, {}),
(mpu.get_model_parallel_group, {}),
(mpu.get_embedding_group, {}),
(mpu.get_position_embedding_group, {}),
):
Comment on lines +626 to +637

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved readability and maintainability, consider extracting this list of group getters into a named constant. This makes the purpose of the list clearer and simplifies future modifications. You can then move the constant definition to the module level if you prefer.

            # List of (getter, kwargs) for all communicators to warm up.
            GROUPS_TO_WARM_UP = [
                (mpu.get_data_parallel_group, {'with_context_parallel': True}),
                (mpu.get_data_parallel_group, {}),
                (mpu.get_context_parallel_group, {}),
                (mpu.get_tensor_model_parallel_group, {}),
                (mpu.get_pipeline_model_parallel_group, {}),
                (mpu.get_model_parallel_group, {}),
                (mpu.get_embedding_group, {}),
                (mpu.get_position_embedding_group, {}),
            ]
            for getter, kwargs in GROUPS_TO_WARM_UP:

try:
group = getter(**kwargs)
except (AssertionError, ValueError, TypeError):
continue
for g in (group if isinstance(group, list) else [group]):
if g is not None:
torch.distributed.all_reduce(dummy, group=g)
warmed += 1
torch.cuda.synchronize()
logger.info(f'NCCL communicator warm-up done ({warmed} groups).')

self.call_event('on_train_begin')
self._train_metrics = {}

Expand Down
Loading