
LoRA weight file saved during FSDP2 training is empty #156

@PlutoQyl

Description

```python
elif self.accelerator.is_fsdp2:
    # FSDP/FSDP2
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

    if state_dict_keys is not None:
        # Temporarily mark unwanted params as frozen
        # This `requires_grad` trick does not work correctly. Don't know why.
        original_state = {}

        # Freeze unwanted params
        for name, param in model.named_parameters():
            original_state[name] = param.requires_grad
            param.requires_grad = is_param_match_key(name, state_dict_keys)

        options = StateDictOptions(
            full_state_dict=True,
            broadcast_from_rank0=True,
            cpu_offload=True,
            ignore_frozen_params=True,
        )
        state_dict = get_model_state_dict(model, options=options)

        # Restore original state
        for name, param in model.named_parameters():
            param.requires_grad = original_state[name]
```
I noticed this section in the code. Does the author have a solution for it yet? I hit this problem while training Online DPO (ZeRO-2 runs out of GPU memory when training Flux2klein-9B on H20).
