I thought this would be helpful for other people. I had trouble getting the ResNet defined in this repo to run properly, but the provided weights work well with PyTorch's default resnet50 from torchvision.
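# Loading weights
import torch
import torch.nn as nn
from torchvision.models import resnet50

# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if the checkpoint
# fails to load because it also stores non-tensor objects, pass weights_only=False
# (only do this for files you trust).
state_dict = torch.load("supcon_official.pth", map_location="cpu")

# Correcting the key names: the weights are stored under "model", with
# "module." and "encoder." prefixes that torchvision's resnet50 does not use
state_dict = state_dict["model"]
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("module.", "")
    k = k.replace("encoder.", "")
    new_state_dict[k] = v
state_dict = new_state_dict

# Using the standard PyTorch (torchvision) resnet50, with the classifier replaced
# by an identity so the model returns the pooled 2048-d features
model = resnet50()
model.fc = nn.Identity()

# The projection head is not needed, so drop its weights
for key in ("head.0.weight", "head.0.bias", "head.2.weight", "head.2.bias"):
    state_dict.pop(key, None)

# This should do the trick: strict=True raises an error on any missing or unexpected key
model.load_state_dict(state_dict, strict=True)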
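The `str.replace` calls above remove "module." and "encoder." anywhere inside a key. A slightly more defensive variant strips these prefixes only from the start of each key, in whatever order they appear, and drops every "head." entry instead of listing the four keys by hand. This is a minimal sketch that assumes the checkpoint keys are plain strings. `remap_supcon_keys` is a hypothetical helper name and is not part of this repo or of torchvision.

```python
from typing import Any, Dict, Iterable


def remap_supcon_keys(
    state_dict: Dict[str, Any],
    prefixes: Iterable[str] = ("module.", "encoder."),
    drop_prefix: str = "head.",
) -> Dict[str, Any]:
    """Strip leading wrapper prefixes and drop projection-head entries.

    Hypothetical helper (assumption): not part of this repo or torchvision.
    """
    prefixes = tuple(prefixes)
    remapped: Dict[str, Any] = {}
    for key, value in state_dict.items():
        # Repeatedly strip any listed prefix from the start of the key,
        # so both "encoder.module." and "module.encoder." are handled
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        # Skip the projection head (e.g. head.0.*, head.2.*)
        if key.startswith(drop_prefix):
            continue
        remapped[key] = value
    return remapped


if __name__ == "__main__":
    toy = {"encoder.module.conv1.weight": 1, "head.0.weight": 2, "head.2.bias": 3}
    print(remap_supcon_keys(toy))  # {'conv1.weight': 1}
```

With the checkpoint loaded as above, usage would be `model.load_state_dict(remap_supcon_keys(ckpt["model"]), strict=True)`, where `ckpt` is the dict returned by `torch.load`.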