forked from CSAILVision/places365
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert_model.py
More file actions
24 lines (19 loc) · 829 Bytes
/
convert_model.py
File metadata and controls
24 lines (19 loc) · 829 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch.autograd import Variable as V
import torchvision.models as models
from PIL import Image
from torchvision import transforms as trn
from torch.nn import functional as F
import os
# Convert a Places365 training checkpoint (which stores a state_dict) into a
# file holding the whole pickled model, loaded onto the CPU.

# The architecture to use; the checkpoint file name is derived from it.
arch = 'resnet18'

# Create the network architecture with 365 output classes.
model = models.__dict__[arch](num_classes=365)

model_weight = '%s_places365.pth.tar' % arch
# map_location='cpu' remaps every stored tensor onto the CPU, so a model
# trained on GPU can be deployed on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(model_weight, map_location='cpu')

# The data-parallel wrapper used during training adds 'module.' in front of
# each layer name; strip only that leading prefix so the keys match the bare
# model's parameter names.
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['state_dict'].items()
}
model.load_state_dict(state_dict)
model.eval()
model.cpu()

# Save the entire model object (architecture + weights), not just its
# state_dict.
# NOTE(review): because this is a pickled nn.Module, newer PyTorch releases
# may need torch.load(..., weights_only=False) to read it back; confirm
# against the torch version in use.
output_path = 'whole_' + model_weight
torch.save(model, output_path)
print('save to ' + output_path)