Lines 112 to 122 in 87c3e44:

```python
fast_weights = OrderedDict(
    (name, param - torch.mul(meta_step_size, torch.clamp(grad, 0 - clip_value, clip_value)))
    for ((name, param), grad) in zip(model.named_parameters(), grads))
learner = copy.deepcopy(model)
learner.load_state_dict(fast_weights, strict=False)
output_outer = learner(image_freq)
del fast_weights
loss_outer = loss_fun(output_outer, class_l)
loss = loss_inner + loss_outer
```
Your meta-learning implementation is wrong: `load_state_dict()` copies tensor values without preserving the computation graph, so `loss_outer` will never propagate gradients through `fast_weights` back to the original model parameters. Be careful :)
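One way to keep the graph intact is to evaluate the model with the fast weights directly, instead of copying them into a new module. Here is a minimal, self-contained sketch using `torch.func.functional_call` (PyTorch ≥ 2.0); the tiny `nn.Linear` model, tensors, and hyperparameter values are stand-ins for illustration, not the repo's actual code:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Stand-in setup mirroring the issue's inner/outer loop (hypothetical shapes).
model = nn.Linear(4, 2)
loss_fun = nn.CrossEntropyLoss()
meta_step_size, clip_value = 0.1, 1.0

x_inner = torch.randn(8, 4)
x_outer = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

# Inner step: create_graph=True keeps the update differentiable, so the
# outer loss can backpropagate through the inner gradient step.
loss_inner = loss_fun(model(x_inner), y)
grads = torch.autograd.grad(loss_inner, model.parameters(), create_graph=True)

fast_weights = {
    name: param - meta_step_size * torch.clamp(grad, -clip_value, clip_value)
    for (name, param), grad in zip(model.named_parameters(), grads)
}

# Outer step: no deepcopy / load_state_dict, so the graph from fast_weights
# back to model.parameters() stays connected.
output_outer = functional_call(model, fast_weights, (x_outer,))
loss_outer = loss_fun(output_outer, y)
loss = loss_inner + loss_outer
loss.backward()

# Gradients now reach the original parameters through both loss terms.
assert all(p.grad is not None for p in model.parameters())
```

With `load_state_dict`, the `fast_weights` tensors are copied into fresh leaf parameters and the link to the inner gradient step is severed; `functional_call` instead runs the forward pass with the fast-weight tensors themselves, so `loss.backward()` reaches the original parameters.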