Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions swift/megatron/utils/megatron_lm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,11 @@ def load_mcore_checkpoint(args,
tracker_path = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
iteration = _load_iteration(tracker_path)
checkpoint_dir = os.path.join(load_dir, f'iter_{iteration:07d}')
- state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir)
+ common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_dir)

ckpt_tp_pp = (
- state_dict['args'].tensor_model_parallel_size,
- state_dict['args'].pipeline_model_parallel_size,
+ common_state_dict['args'].tensor_model_parallel_size,
+ common_state_dict['args'].pipeline_model_parallel_size,
)
run_tp_pp = (
args.tensor_model_parallel_size,
Expand All @@ -416,14 +416,14 @@ def load_mcore_checkpoint(args,
mismatch_msg = f'(TP, PP) mismatch after resume ({run_tp_pp} vs {ckpt_tp_pp} from checkpoint)'
# Determine if RNG state will be loaded
if (ckpt_tp_pp == run_tp_pp and not finetune and not no_load_rng
- and not getattr(state_dict['args'], 'no_save_rng', False)):
+ and not getattr(common_state_dict['args'], 'no_save_rng', False)):
gen_sd_rng_state = _get_rng_state() # we can load the rng state
else:
gen_sd_rng_state = None
if ckpt_tp_pp != run_tp_pp:
logger.info(f'{mismatch_msg}: RNG state will be ignored')
- sharded_sd_metadata = state_dict.get('content_metadata')
- if (not finetune and not no_load_optim and not getattr(state_dict['args'], 'no_save_optim', False)):
+ sharded_sd_metadata = common_state_dict.get('content_metadata')
+ if (not finetune and not no_load_optim and not getattr(common_state_dict['args'], 'no_save_optim', False)):
gen_sd_optim = optimizer
gen_sd_opt_param_scheduler = opt_param_scheduler

Expand Down Expand Up @@ -463,16 +463,16 @@ def load_mcore_checkpoint(args,

if finetune:
iteration = 0
- if 'args' in state_dict and not finetune:
- args.consumed_train_samples = getattr(state_dict['args'], 'consumed_train_samples', 0)
+ if 'args' in common_state_dict and not finetune:
+ args.consumed_train_samples = getattr(common_state_dict['args'], 'consumed_train_samples', 0)

if len(ddp_models) == 1:
ddp_models[0].load_state_dict(state_dict['model'], strict=False)
else:
for i, m in enumerate(ddp_models):
if f'model{i}' not in state_dict:
continue
- m.load_state_dict(state_dict[f'model{i}'])
+ m.load_state_dict(state_dict[f'model{i}'], strict=False)

if not finetune and not no_load_optim:
if optimizer is not None:
Expand Down
Loading