forked from ytZhang99/CF-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
63 lines (50 loc) · 2.33 KB
/
test.py
File metadata and controls
63 lines (50 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import cv2
import time
import torch
import torch.nn
import numpy as np
import torchvision.transforms as transforms
from tqdm import trange
from model import CFNet
from option import args
class Test:
def __init__(self):
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])])
self.test_dir_pre = args.dir_test
self.over_imgs = os.listdir(self.test_dir_pre + 'lr_over/')
self.over_imgs.sort()
self.under_imgs = os.listdir(self.test_dir_pre + 'lr_under/')
self.under_imgs.sort()
assert len(self.over_imgs) == len(self.under_imgs)
self.num_imgs = len(self.over_imgs)
self.model = CFNet().cuda()
self.state = torch.load(args.model_path + args.model)
self.model.load_state_dict(self.state['model'])
self.test_time = []
def test(self):
self.model.eval()
with torch.no_grad():
for idx in trange(self.num_imgs):
img1 = cv2.imread(self.test_dir_pre + 'lr_over/' + self.over_imgs[idx])
img1 = torch.unsqueeze(self.transform(img1), 0)
img2 = cv2.imread(self.test_dir_pre + 'lr_under/' + self.under_imgs[idx])
img2 = torch.unsqueeze(self.transform(img2), 0)
assert img1.shape == img2.shape
save_name = os.path.splitext(os.path.split(self.over_imgs[idx])[1])[0]
img1 = img1.cuda()
img2 = img2.cuda()
torch.cuda.synchronize()
start_time = time.time()
sr_over, sr_under = self.model(img1, img2)
img_fused = 0.5 * sr_over[-1] + 0.5 * sr_under[-1]
img_fused = img_fused.squeeze(0)
torch.cuda.synchronize()
end_time = time.time()
self.test_time.append(end_time - start_time)
img_fused = img_fused.cpu().numpy()
img_fused = np.transpose(img_fused, (1, 2, 0))
img_fused = img_fused.astype(np.uint8)
cv2.imwrite(os.path.join(args.save_dir, str(save_name) + args.ext), img_fused)
print('The average testing time is {:.4f} s.'.format(np.mean(self.test_time)))