```python
elif self.accelerator.is_fsdp2:
    # FSDP/FSDP2
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

    if state_dict_keys is not None:
        # Temporarily mark unwanted params as frozen so that
        # `ignore_frozen_params=True` below skips gathering them.
        # This `requires_grad` trick does not work correctly. Don't know why.
        original_state = {}
        for name, param in model.named_parameters():
            original_state[name] = param.requires_grad
            param.requires_grad = is_param_match_key(name, state_dict_keys)

    options = StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
        cpu_offload=True,
        # Only drop frozen params when we froze the unwanted ones above;
        # otherwise genuinely frozen params would silently be lost.
        ignore_frozen_params=state_dict_keys is not None,
    )
    state_dict = get_model_state_dict(model, options=options)

    if state_dict_keys is not None:
        # Restore the original requires_grad flags
        for name, param in model.named_parameters():
            param.requires_grad = original_state[name]
```
I noticed this segment in the code. Does the author have a fix for this yet? I ran into this problem while training Online DPO (ZeRO-2 training of Flux2klein-9B on H20 GPUs runs out of GPU memory).
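
In case it helps, here is an untested sketch of the workaround I am considering: skip `get_model_state_dict` and the `requires_grad` trick entirely, and all-gather only the wanted parameters one at a time via `DTensor.full_tensor()`. It assumes a recent PyTorch where `torch.distributed.tensor.DTensor` is public, that FSDP2 shards parameters as `DTensor`s, and it reuses `model`, `state_dict_keys`, and `is_param_match_key` from the snippet above.

```python
# Untested sketch: gather wanted parameters one at a time instead of
# materializing a full state dict. Assumes FSDP2 DTensor sharding and
# reuses `model`, `state_dict_keys`, `is_param_match_key` from above.
from torch.distributed.tensor import DTensor

state_dict = {}
for name, param in model.named_parameters():
    if state_dict_keys is not None and not is_param_match_key(name, state_dict_keys):
        continue  # unwanted params are never all-gathered
    # full_tensor() all-gathers the shards of this single parameter,
    # so peak GPU memory grows by one full parameter, not the whole model
    full = param.full_tensor() if isinstance(param, DTensor) else param
    state_dict[name] = full.detach().cpu()  # offload right away to free GPU memory
```

This trades one big collective for many small ones, so it is slower, but peak GPU memory stays bounded by the largest single parameter. All ranks take the same branch for each parameter, so the collectives stay in sync. No idea whether the result matches what the trainer expects downstream, though.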