
How do I load a checkpoint and continue training? #41

@wanghao19970205

Description

After merging the checkpoint into pytorch_model.bin, I load it directly and call fit, but it fails with a flash_att-related error:
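```python
import os
import torch

# config, trainer, train_loader and saved are defined earlier in the training script
ckpt_path = os.path.join(config['checkpoint_dir'], 'pytorch_model.bin')
ckpt = torch.load(ckpt_path, map_location='cpu')
# strict=False skips mismatched keys; print them so nothing is dropped silently
msg = trainer.model.load_state_dict(ckpt, strict=False)
print('missing:', msg.missing_keys, 'unexpected:', msg.unexpected_keys)
best_valid_score, best_valid_result = trainer.fit(
    train_loader, None, saved=saved, show_progress=config['show_progress']
)
```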
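For context on "continue training": a weights-only pytorch_model.bin restores model parameters but not the optimizer state or the epoch counter, so the optimizer starts from a fresh state. Below is a minimal sketch of the generic PyTorch pattern for a full resume. It is not this repository's checkpoint format; it assumes you control the save side, and `save_training_state` / `load_training_state` are hypothetical helper names. It stores weights, optimizer state and progress in one file, then restores them with strict key matching so a mismatch fails loudly instead of being skipped the way `strict=False` allows.

```python
import os
import tempfile

import torch
from torch import nn


def save_training_state(path, model, optimizer, epoch):
    """Hypothetical helper: save weights, optimizer state and progress together."""
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch},
        path,
    )


def load_training_state(path, model, optimizer):
    """Hypothetical helper: restore state in place and return the next epoch to run."""
    state = torch.load(path, map_location="cpu")
    # Default strict=True raises on any missing or unexpected parameter key
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1


# Toy model; one optimizer step so AdamW has moment estimates to save
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).pow(2).mean().backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "resume.pt")
save_training_state(path, model, optimizer, epoch=3)

# Fresh objects stand in for a restarted process
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.AdamW(new_model.parameters(), lr=1e-3)
start_epoch = load_training_state(path, new_model, new_optimizer)
print(start_epoch)  # 4
```

If the training loop also uses an LR scheduler, its `state_dict()` can go into the same dictionary and be restored with `load_state_dict()` in the same way.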