-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo.py
More file actions
63 lines (56 loc) · 2.16 KB
/
demo.py
File metadata and controls
63 lines (56 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import pdb
from PIL import Image
import matplotlib.pyplot as plt
from DataProvider import DataProviderUtil
from model import featurenet,Block
import time
import numpy as np
import postprocessing, LossFuncs,pix_2_spix
import os
from imguidedfilter import imguidedfilter
from torchvision.transforms import functional as F
import sys
class demo():
    """Run ``featurenet`` on one image from ``./demo_im/`` and plot the result.

    Construction builds the network on the GPU and makes sure
    ``./Results_demo/`` exists; ``start_demo`` loads the checkpoint, runs
    inference and optionally saves a side-by-side comparison figure.
    """

    def __init__(self, image_name='table.png'):
        """Build the network and the post-processing helper.

        Args:
            image_name: file name of the input image, relative to
                ``./demo_im/``; the same name is used for the output figure
                inside ``./Results_demo/``.
        """
        self.net = featurenet(Block, planes=[64, 128, 256]).cuda()
        self.pca_util = postprocessing.PCA_util()
        self.image_name = image_name
        self.demo_im_root = './demo_im/'
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs('./Results_demo', exist_ok=True)

    def start_demo(self, save_results=False):
        """Load the checkpoint, run the network on the demo image, optionally save.

        Args:
            save_results: when True, write the comparison figure via
                ``save_figures``.
        """
        self.load_saved_model()
        # Switch layers such as BatchNorm/Dropout (if the model has any) to
        # inference behaviour; otherwise the net would still be in training mode.
        self.net.eval()
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(self.demo_im_root + self.image_name) as pil_image:
            rgb_image = pil_image.convert('RGB')
        # to_tensor maps the HxWx3 uint8 array to a 3xHxW float tensor in [0, 1];
        # [None] adds the leading batch dimension.
        image = F.to_tensor(np.array(rgb_image))
        image = image[None, :, :, :].cuda()
        # Inference only: no autograd graph is needed (also saves GPU memory).
        with torch.no_grad():
            outFeats = self.net(image)
        if save_results:
            self.save_figures(outFeats, image)

    def load_saved_model(self):
        """Load weights from ``./saved_model/LatestSavednet.pth`` into ``self.net``.

        Returns:
            ``checkpoint['epoch'] + 1`` when a checkpoint was loaded, else 0.

        Only a missing checkpoint file is tolerated (the network then keeps
        its initial weights). A corrupt file, missing keys, or a state dict
        that does not match the architecture propagate as errors instead of
        being silently ignored.
        """
        try:
            checkpoint = torch.load('./saved_model/LatestSavednet.pth')
        except FileNotFoundError:
            print('No saved Model found... running demo with untrained weights')
            return 0
        self.net.load_state_dict(checkpoint['model_state_dict'])
        ep = checkpoint['epoch'] + 1
        print('loaded Latest Model and starting epoch:', ep)
        return ep

    def save_figures(self, outFeats, image):
        """Save the input image next to its guided-filtered dominant features.

        Args:
            outFeats: network output; ``outFeats[0]`` is permuted with
                (1, 2, 0), so presumably a (C, H, W) feature map for the single
                batch item — TODO confirm against ``featurenet``.
            image: the 1x3xHxW input tensor that was fed to the network.

        The figure is written to ``./Results_demo/<image_name>`` and then closed.
        """
        feats_np = outFeats[0].permute(1, 2, 0).cpu().detach().numpy()
        simp_gui, _, _ = self.pca_util.find_guidedfiltered_dom_feats(
            feats_np, image[0].cpu().permute(1, 2, 0).numpy())
        fig = plt.figure(figsize=(18, 9))
        try:
            plt.subplot(1, 2, 1)
            # Fresh conversion: the array passed above may be modified by the helper.
            plt.imshow(image[0].cpu().permute(1, 2, 0).numpy())
            plt.axis('off')
            plt.title('Original Image')
            plt.subplot(1, 2, 2)
            plt.imshow(simp_gui)
            plt.axis('off')
            plt.title('Guided Filtered Dominant Features')
            plt.savefig('./Results_demo/' + self.image_name)
        finally:
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)
        print('Result Saved...')
if __name__ == '__main__':
    # Usage: python demo.py [image_name]
    # Without an argument, fall back to demo's default image instead of
    # crashing with IndexError on sys.argv[1].
    if len(sys.argv) > 1:
        inst = demo(image_name=sys.argv[1])
    else:
        inst = demo()
    inst.start_demo(save_results=True)