-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_sucode.py
More file actions
157 lines (121 loc) · 5.48 KB
/
test_sucode.py
File metadata and controls
157 lines (121 loc) · 5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import cv2
import glob
import os
from tqdm import tqdm
import torch
from yaml import load
import pdb
import numpy as np
import matplotlib as m
import matplotlib.pyplot as plt
from PIL import Image
import pyiqa
from basicsr.utils import img2tensor, tensor2img, imwrite
from basicsr.archs.my_sucode_arch import SUCode
from basicsr.utils.download_util import load_file_from_url
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluation metrics: PSNR/SSIM crop a 4-px border (standard SR protocol),
# LPIPS is computed on the full image.
metric_funcs = {
    'psnr': pyiqa.create_metric('psnr', device=device, crop_border=4),
    'ssim': pyiqa.create_metric('ssim', device=device, crop_border=4),
    'lpips': pyiqa.create_metric('lpips', device=device),
}
def main(args):
    """Run SUCode restoration over the input image(s) and report metrics.

    Loads the generator checkpoint from ``args.weight``, rebuilds the
    codebook configuration from the checkpoint tensors, restores every image
    found at ``args.input``, writes results to ``args.output`` and prints
    averaged PSNR / SSIM / LPIPS scores.

    Args:
        args: Parsed CLI namespace with ``input``, ``weight``, ``output``,
            ``max_size`` and ``vis_weight`` attributes.
    """
    metric_results = {'psnr': 0, 'ssim': 0, 'lpips': 0}
    weight_path = args.weight

    # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
    model_params = torch.load(weight_path, map_location=device)['params']

    # Recover the per-group codebook shapes from the checkpoint tensors.
    codebook_dim = np.array([v.size() for k, v in model_params.items()
                             if 'quantize_group' in k])
    codebook_dim_list = []
    for dims in codebook_dim:
        entry = dims.tolist()
        # 32 is prepended as the leading codebook dimension expected by
        # SUCode — assumed from the original code; TODO confirm.
        entry.insert(0, 32)
        codebook_dim_list.append(entry)
    print('Codebook dimension list:', codebook_dim_list)

    recon_model = SUCode(codebook_params=codebook_dim_list, LQ_stage=True,
                         AdaCode_stage=True, Coder_stage=True,
                         batch_size=4, weight_softmax=False).to(device)
    recon_model.load_state_dict(model_params, strict=False)
    recon_model.eval()

    from torchinfo import summary
    summary(recon_model, ((1, 3, 128, 128)))

    os.makedirs(args.output, exist_ok=True)
    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*.*')))
    if not paths:
        # Avoid ZeroDivisionError in the metric averaging below.
        print('No input images found at {}'.format(args.input))
        return

    # Hoisted out of the loop: the resize transform is input-independent.
    from torchvision.transforms import Resize, InterpolationMode
    torch_resize = Resize([256, 256], interpolation=InterpolationMode.BICUBIC)

    pbar = tqdm(total=len(paths), unit='image')
    for path in paths:
        img_name = os.path.basename(path)
        pbar.set_description(f'Test {img_name}')
        img_HR = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img_HR_tensor = img2tensor(img_HR).to(device) / 255.
        img_HR_tensor = torch_resize(img_HR_tensor)
        img_HR_tensor = img_HR_tensor.unsqueeze(0)  # B x C x H x W
        # Inference only — no_grad avoids building the autograd graph.
        with torch.no_grad():
            output_HR = recon_model.test(img_HR_tensor)
        if args.vis_weight:
            # With vis_weight, test() is assumed to return
            # (restored, weight_map) — TODO confirm against SUCode.test.
            weight_map = output_HR[1]
            vis_weight(weight_map, os.path.join(args.output, 'weight_map', img_name))
            output = output_HR[0]
        else:
            output = output_HR
        output = output.clamp(0, 1)  # B x C x H x W
        output_img = tensor2img(output)
        imwrite(output_img, os.path.join(args.output, f'{img_name}'))
        for name, func in metric_funcs.items():
            metric_results[name] += func(img_HR_tensor, output).item()
        pbar.update(1)
    pbar.close()

    # Average the accumulated metric totals over the processed images.
    for name in metric_results:
        metric_results[name] /= len(paths)
    print('Result for {}'.format(args.weight))
    print(metric_results)
def vis_weight(weight, save_path):
    """Render per-group weight maps as one horizontal grid and save it.

    Args:
        weight: Weight-map tensor assumed to be B x n x 1 x H x W (per the
            original comment) — TODO confirm against recon_model.test output.
        save_path: Destination path for the rendered figure.
    """
    weight = weight.cpu().numpy()
    # Normalize for display: center, scale by 2*std, take magnitude.
    std = weight.std()
    if std == 0:
        std = 1.0  # constant map: guard against divide-by-zero / NaNs
    norm_weight = np.abs((weight - weight.mean()) / std / 2)
    norm_weight = np.clip(norm_weight * 255, 0, 255).astype(np.uint8)

    # Tile the n maps side by side with a 1-px black separator column.
    display_grid = np.zeros(
        (weight.shape[3], (weight.shape[4] + 1) * weight.shape[1] - 1))
    for img_id in range(len(norm_weight)):
        for c in range(norm_weight.shape[1]):
            display_grid[:, c * weight.shape[4] + c:
                         (c + 1) * weight.shape[4] + c] = \
                norm_weight[img_id, c, 0, :, :]
    # NOTE(review): when B > 1 each image overwrites the previous one in
    # display_grid, so only the last image's weights reach the saved figure.
    plt.figure(figsize=(30, 150))
    plt.axis('off')
    plt.imshow(display_grid)
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()
if __name__ == '__main__':
    # Command-line interface for the SUCode inference script.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str,
                        default='./dataset/valid/images/',
                        help='Input image or folder')
    parser.add_argument('-w', '--weight', type=str,
                        default='./models/net_sucode_g_best_.pth',
                        help='path for model weights')
    parser.add_argument('-o', '--output', type=str,
                        default='./output',
                        help='Output folder')
    parser.add_argument('--suffix', type=str,
                        default='stage2',
                        help='Suffix of the restored image')
    parser.add_argument('--max_size', type=int,
                        default=256,
                        help='Max image size for whole image inference, otherwise use tiled_test')
    parser.add_argument('--vis_weight', action='store_true',
                        help='visualize weight map')
    args = parser.parse_args()

    # Create the weight-map directory up front so vis_weight can write into it.
    if args.vis_weight:
        os.makedirs(os.path.join(args.output, 'weight_map'), exist_ok=True)
    main(args)