
Code for using the ImageNet pretrained model #146

@LightingMc


I thought this would be helpful for other people. I had trouble getting the ResNet defined in this repo to run properly, but the released weights work well with PyTorch's standard ResNet-50 from torchvision.

Loading weights

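import torch

# map_location="cpu" loads every tensor onto the CPU, wherever it was saved from.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True; if the checkpoint also
# stores non-tensor objects you may need weights_only=False (only for trusted files).
state_dict = torch.load("supcon_official.pth", map_location="cpu")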
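Before renaming anything, it can help to see which prefixes the checkpoint's keys actually carry. Below is a minimal sketch. The helper `prefix_counts` is hypothetical and is not part of this repo or PyTorch. It counts keys by their leading dotted components and works on any mapping with string keys.

```python
from collections import Counter


def prefix_counts(state_dict, depth=2):
    """Count keys by their first `depth` dot-separated components.

    Works on any mapping from string keys to values, such as a PyTorch state dict.
    """
    return Counter(".".join(key.split(".")[:depth]) for key in state_dict)


# Usage, assuming the loaded checkpoint keeps its weights under 'model':
# print(prefix_counts(state_dict["model"]))
```

Output such as `module.encoder` and `module.head` would show which prefixes to strip and which keys belong to the head.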

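Correct the key names. The weights sit under the checkpoint's `model` entry, and their keys carry `module.` and `encoder.` prefixes that torchvision's ResNet does not use.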

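# The weights live under the 'model' key of the checkpoint
state_dict = state_dict['model']

# Strip "module.", typically added when the model was wrapped in (Distributed)DataParallel
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("module.", "")
    new_state_dict[k] = v
state_dict = new_state_dict

# Strip "encoder." so the backbone keys match torchvision's parameter names
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("encoder.", "")
    new_state_dict[k] = v
state_dict = new_state_dict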
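The two loops above can also be folded into a single pass. Below is a minimal sketch with a hypothetical helper, `clean_state_dict`. It accepts either the full checkpoint (weights under `'model'`) or a bare state dict. It also removes each prefix only when the prefix starts the key, whereas `str.replace` would also rewrite a match in the middle of a key.

```python
def clean_state_dict(checkpoint, prefixes=("module.", "encoder.")):
    """Return a copy of the weights with the given leading prefixes removed.

    Accepts a checkpoint dict holding the weights under 'model', or a bare
    state dict. Prefixes are stripped in the order given, and only when they
    appear at the start of the key.
    """
    state_dict = checkpoint.get("model", checkpoint)
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Usage:
# state_dict = clean_state_dict(torch.load("supcon_official.pth", map_location="cpu"))
```

For this checkpoint the result should match the original loops, as long as neither substring also occurs mid-key.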

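Build the standard torchvision ResNet-50 and replace its classifier with an identity, so the model outputs the 2048-dimensional pooled features.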

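import torch.nn as nn
from torchvision.models import resnet50

model = resnet50()  # randomly initialized; the SupCon weights are loaded below
model.fc = nn.Identity()  # assigning replaces the old fc layer, so no `del` is needed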

The projection head (`head.*`) is not part of the backbone, so drop its weights.

state_dict.pop("head.0.weight", None)
state_dict.pop("head.0.bias", None)
state_dict.pop("head.2.weight", None)
state_dict.pop("head.2.bias", None)
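Instead of popping the four `head.*` entries by name, you can filter out every key under the head prefix. This also covers heads with more or fewer layers. Below is a minimal sketch. `drop_prefix` is a hypothetical helper and not part of this repo.

```python
def drop_prefix(state_dict, prefix="head."):
    """Return a copy of state_dict without any key that starts with prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}


# Usage:
# state_dict = drop_prefix(state_dict, "head.")
```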

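Finally, load the weights. With `strict=True`, `load_state_dict` raises an error if any key is missing or unexpected, so a successful call means every backbone parameter was matched.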

model.load_state_dict(state_dict, strict=True)
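If the strict load raises, you can diff the two key sets to see why. Below is a minimal sketch with a hypothetical helper, `compare_keys`. PyTorch's own `load_state_dict(..., strict=False)` also returns the `missing_keys` and `unexpected_keys` it found, which gives the same information.

```python
def compare_keys(expected_keys, provided_keys):
    """Return (missing, unexpected), both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not expect
    """
    expected, provided = set(expected_keys), set(provided_keys)
    return sorted(expected - provided), sorted(provided - expected)


# Usage, assuming `model` and `state_dict` from the steps above:
# missing, unexpected = compare_keys(model.state_dict().keys(), state_dict.keys())
# print("missing:", missing[:10])
# print("unexpected:", unexpected[:10])
```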
