diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 0484d40107..5a4b16b6dd 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -403,11 +403,11 @@ def load_mcore_checkpoint(args, tracker_path = os.path.join(load_dir, 'latest_checkpointed_iteration.txt') iteration = _load_iteration(tracker_path) checkpoint_dir = os.path.join(load_dir, f'iter_{iteration:07d}') - state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir) + common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir) ckpt_tp_pp = ( - state_dict['args'].tensor_model_parallel_size, - state_dict['args'].pipeline_model_parallel_size, + common_state_dict['args'].tensor_model_parallel_size, + common_state_dict['args'].pipeline_model_parallel_size, ) run_tp_pp = ( args.tensor_model_parallel_size, @@ -416,14 +416,14 @@ def load_mcore_checkpoint(args, mismatch_msg = f'(TP, PP) mismatch after resume ({run_tp_pp} vs {ckpt_tp_pp} from checkpoint)' # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not finetune and not no_load_rng - and not getattr(state_dict['args'], 'no_save_rng', False)): + and not getattr(common_state_dict['args'], 'no_save_rng', False)): gen_sd_rng_state = _get_rng_state() # we can load the rng state else: gen_sd_rng_state = None if ckpt_tp_pp != run_tp_pp: logger.info(f'{mismatch_msg}: RNG state will be ignored') - sharded_sd_metadata = state_dict.get('content_metadata') - if (not finetune and not no_load_optim and not getattr(state_dict['args'], 'no_save_optim', False)): + sharded_sd_metadata = common_state_dict.get('content_metadata') + if (not finetune and not no_load_optim and not getattr(common_state_dict['args'], 'no_save_optim', False)): gen_sd_optim = optimizer gen_sd_opt_param_scheduler = opt_param_scheduler @@ -463,8 +463,8 @@ def load_mcore_checkpoint(args, if finetune: iteration = 0 - if 'args' in state_dict and not finetune: - args.consumed_train_samples = getattr(state_dict['args'], 'consumed_train_samples', 0) + if 'args' in common_state_dict and not finetune: + args.consumed_train_samples = getattr(common_state_dict['args'], 'consumed_train_samples', 0) if len(ddp_models) == 1: ddp_models[0].load_state_dict(state_dict['model'], strict=False) @@ -472,7 +472,7 @@ def load_mcore_checkpoint(args, for i, m in enumerate(ddp_models): if f'model{i}' not in state_dict: continue - m.load_state_dict(state_dict[f'model{i}']) + m.load_state_dict(state_dict[f'model{i}'], strict=False) if not finetune and not no_load_optim: if optimizer is not None: