From eef67204edb07fad5d5ddff13474a601bd55b975 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 18 Jun 2026 11:51:59 +0800 Subject: [PATCH 1/3] fix ddp_config --- swift/megatron/utils/megatron_lm_utils.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index b6f82feae5..d67dc7fb4d 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -517,7 +517,8 @@ def wrap_model(args, models, wrap_with_ddp: bool = True): # DDP if not wrap_with_ddp: - return + return models + kwargs = {} for f in dataclasses.fields(DistributedDataParallelConfig): if hasattr(args, f.name): @@ -525,6 +526,11 @@ def wrap_model(args, models, wrap_with_ddp: bool = True): kwargs['check_for_nan_in_grad'] = True ddp_config = DistributedDataParallelConfig(**kwargs) + # If num_buckets is set, compute bucket_size from total parameters. + num_parameters = sum(sum(p.nelement() for p in m.parameters()) for m in models) + if ddp_config.bucket_size is None and getattr(ddp_config, 'num_buckets', None) is not None: + ddp_config.bucket_size = num_parameters // ddp_config.num_buckets + # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. # If bucket_size is not provided as an input, use sane default. # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL @@ -536,7 +542,15 @@ def wrap_model(args, models, wrap_with_ddp: bool = True): if not ddp_config.overlap_grad_reduce: ddp_config.bucket_size = None - with torch.cuda.stream(torch.cuda.Stream()): + # For non-first pipeline-parallel ranks, disable bucket_size to avoid unnecessary overhead. + pp_rank = mpu.get_pipeline_model_parallel_rank() + if pp_rank > 0: + ddp_config.bucket_size = None + + # Setup stream for DDP initialization with proper synchronization. + ddp_stream = torch.cuda.Stream() + ddp_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(ddp_stream): models = [ DDP( config=config, @@ -547,6 +561,8 @@ def wrap_model(args, models, wrap_with_ddp: bool = True): disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, ) for (model_chunk_idx, model_chunk) in enumerate(models) ] + # Ensure DDP initialization completes before proceeding on the default stream. + torch.cuda.current_stream().wait_stream(ddp_stream) # Broadcast params from data parallel src rank to other data parallel ranks. if args.data_parallel_random_init: From c333bd8f57efd4a0115ba30f09393e28db3f7f0a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 18 Jun 2026 13:40:31 +0800 Subject: [PATCH 2/3] update load_mcore_checkpoint --- swift/megatron/utils/megatron_lm_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index d67dc7fb4d..ede7f3b52a 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -403,11 +403,11 @@ def load_mcore_checkpoint(args, tracker_path = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') iteration = _load_iteration(tracker_path) checkpoint_dir = os.path.join(load_dir, f'iter_{iteration:07d}') - state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir) + common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir) ckpt_tp_pp = ( - state_dict['args'].tensor_model_parallel_size, - state_dict['args'].pipeline_model_parallel_size, + common_state_dict['args'].tensor_model_parallel_size, + common_state_dict['args'].pipeline_model_parallel_size, ) run_tp_pp = ( args.tensor_model_parallel_size, @@ -416,14 +416,14 @@ def load_mcore_checkpoint(args, mismatch_msg = f'(TP, PP) mismatch after resume ({run_tp_pp} vs {ckpt_tp_pp} from checkpoint)' # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not finetune and not no_load_rng - and not getattr(state_dict['args'], 'no_save_rng', False)): + and not getattr(common_state_dict['args'], 'no_save_rng', False)): gen_sd_rng_state = _get_rng_state() # we can load the rng state else: gen_sd_rng_state = None if ckpt_tp_pp != run_tp_pp: logger.info(f'{mismatch_msg}: RNG state will be ignored') - sharded_sd_metadata = state_dict.get('content_metadata') - if (not finetune and not no_load_optim and not getattr(state_dict['args'], 'no_save_optim', False)): + sharded_sd_metadata = common_state_dict.get('content_metadata') + if (not finetune and not no_load_optim and not getattr(common_state_dict['args'], 'no_save_optim', False)): gen_sd_optim = optimizer gen_sd_opt_param_scheduler = opt_param_scheduler @@ -463,8 +463,8 @@ def load_mcore_checkpoint(args, if finetune: iteration = 0 - if 'args' in state_dict and not finetune: - args.consumed_train_samples = getattr(state_dict['args'], 'consumed_train_samples', 0) + if 'args' in common_state_dict and not finetune: + args.consumed_train_samples = getattr(common_state_dict['args'], 'consumed_train_samples', 0) if len(ddp_models) == 1: ddp_models[0].load_state_dict(state_dict['model'], strict=False) @@ -472,7 +472,7 @@ def load_mcore_checkpoint(args, for i, m in enumerate(ddp_models): if f'model{i}' not in state_dict: continue - m.load_state_dict(state_dict[f'model{i}']) + m.load_state_dict(state_dict[f'model{i}'], strict=False) if not finetune and not no_load_optim: if optimizer is not None: From cbbfbe9f14712bd7490cb0e76d096056adc2ccf1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 24 Jun 2026 15:59:04 +0800 Subject: [PATCH 3/3] fix --- swift/megatron/utils/megatron_lm_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 572d7dbcd8..5a4b16b6dd 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -531,11 +531,6 @@ def wrap_model(args, models, wrap_with_ddp: bool = True): kwargs['check_for_nan_in_grad'] = True ddp_config = DistributedDataParallelConfig(**kwargs) - # If num_buckets is set, compute bucket_size from total parameters. - num_parameters = sum(sum(p.nelement() for p in m.parameters()) for m in models) - if ddp_config.bucket_size is None and getattr(ddp_config, 'num_buckets', None) is not None: - ddp_config.bucket_size = num_parameters // ddp_config.num_buckets - # In the Megatron FSDP and DDP use path, we need to initialize the bucket size. # If bucket_size is not provided as an input, use sane default. # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL