# test_vis.py — CityPersons / ACSP detection visualization & validation script.
# (Removed web-scrape residue: GitHub UI chrome and displayed line numbers
# that were accidentally captured with the file and are not valid Python.)
import os
import time
import torch
import json
import numpy as np
import time
from copy import deepcopy
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose, ColorJitter
from net.loss import *
from net.network_sn_101 import ACSPNet
from config import Config
from dataloader.loader import *
from util.functions import parse_det_offset
from eval_city.eval_script.eval_demo import validate
from sys import exit
from util import draw_bbox
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pin to GPU 0; must be set before any CUDA context is created
# --- experiment configuration ---
config = Config()
config.test_path = './data/citypersons'  # dataset root passed to CityPersons loader below
config.size_test = (1280, 2560)  # test-time input size — assumes (height, width); TODO confirm against parse_det_offset
config.init_lr = 2e-4  # not used by the validation loop in this file; kept so print_conf reports it
config.offset = True  # network predicts an offset head (consumed by parse_det_offset in val)
config.val = True
config.val_frequency = 1
config.teacher = True  # presumably selects EMA/teacher weights — verify against Config/ACSPNet usage
config.print_conf()  # echo the full configuration for the run log
# dataset: ImageNet-normalized tensors over the CityPersons validation split
testtransform = Compose([ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
testdataset = CityPersons(path=config.test_path, type='val', config=config, transform=testtransform, preloaded=True)
testloader = DataLoader(testdataset, batch_size=1)  # batch_size=1: val() indexes data[0] per image
# net
print('Net...')
net = ACSPNet().cuda()
# position/height/offset loss modules — constructed here but never referenced
# again in this file's visible code (validation only runs the forward pass)
center = cls_pos().cuda()
height = reg_pos().cuda()
offset = offset_pos().cuda()
# snapshot of the freshly-initialized weights; shadowed by a local of the same
# name inside val() when a checkpoint is loaded, so this value goes unused here
teacher_dict = net.state_dict()
def get_gtbox():
    """Load validation ground-truth pedestrian boxes from the CityPersons annotations.

    Returns:
        A list of 500 per-image lists (image ids 1..500); each inner list holds
        the image's non-ignored boxes as ``[x1, y1, x2, y2]``.
    """
    anno = draw_bbox.get_anno('./eval_city/val_gt.json')
    anns = anno.anns
    anns_id = 1
    gt_boxes = []
    # NOTE(review): assumes `anns` is indexable by consecutive ids starting at 1,
    # and that `anns_id < len(anns)` still reaches the final record — confirm
    # against the structure returned by draw_bbox.get_anno.
    for img_id in range(1, 501):
        img_bbox = []
        while anns_id < len(anns) and anns[anns_id]['image_id'] == img_id:
            if anns[anns_id]["ignore"] == 1:  # skip ignore-region annotations
                anns_id += 1
                continue
            x, y, w, h = anns[anns_id]["bbox"]
            # Convert [x, y, w, h] -> [x1, y1, x2, y2] into a fresh list: the
            # original rewrote the shared annotation's bbox in place, which
            # corrupts the data for any later traversal of `anns`.
            img_bbox.append([x, y, x + w, y + h])
            anns_id += 1
        gt_boxes.append(img_bbox)
    return gt_boxes
def val(r, name, log=None):
    """Load a checkpoint, run detection over the validation set and display each
    image with its predicted boxes drawn.

    Args:
        r: ratio argument forwarded to ``parse_det_offset``.
        name: path to the ``.pth.tea`` checkpoint to evaluate; the substring
            between the first 'V' and the first '_' (the version tag, e.g. "V0")
            names the output directory — NOTE(review): this slicing is fragile
            if the path ever contains an earlier 'V' or '_'.
        log: unused; kept for backward compatibility with existing callers.
    """
    import cv2  # hoisted: the original re-imported cv2 on every loop iteration

    base_path = "/mnt/D0D8D177D8D15C72/cityperson_visualize_heatmap/%s/" % (name[name.find('V'):name.find('_')])
    # makedirs(exist_ok=True): os.mkdir raised if the parent was missing or the
    # directory already existed from a previous run.
    os.makedirs(base_path, exist_ok=True)

    net.eval()
    net.load_state_dict(torch.load(name))
    print('Perform validation...')
    for i, data in enumerate(testloader, 0):
        inputs = data.cuda()  # shape e.g. [1, 3, H, W]; batch_size is 1
        with torch.no_grad():
            # renamed locals: `height`/`offset` shadowed the module-level loss modules
            pos_map, height_map, offset_map = net(inputs)

        # Recover a displayable uint8 image from the normalized input tensor.
        original_img = data[0].permute(1, 2, 0).cpu().numpy()  # [H, W, 3]
        original_img = cv2.normalize(original_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

        # Decode center/height/offset maps into final detection boxes.
        dtboxes = parse_det_offset(r, pos_map.cpu().numpy(), height_map.cpu().numpy(), offset_map.cpu().numpy(),
                                   config.size_test, score=0.3, down=4, nms_thresh=0.5)
        box_ori_img = draw_bbox.draw_boxes(original_img, dtboxes, line_thick=2, line_color='green')
        ori_basename = base_path + str(i) + "_ori_bbox" + ".png"
        # Interactive display; press any key to advance to the next image.
        cv2.imshow(ori_basename, box_ori_img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
# Driver: validate every saved teacher checkpoint (epochs 78..149) of this run,
# writing results under the run's validation_result_log directory.
version = 'V0_resnetv2sn50_1centergaussmap_originaladdsenetinresnet_640_1280_1gpuper2img_lr0.0002'
log_folder = './models/' + version + '/validation_result_log/'
log_file = log_folder + version + time.strftime('val_log_%Y%m%d_%H%M%S', time.localtime(time.time())) + '.log'
# makedirs(exist_ok=True): os.mkdir raised when './models/<version>' itself was missing.
os.makedirs(log_folder, exist_ok=True)
# Context manager guarantees the log file is closed (the original leaked the handle).
with open(log_file, 'w') as log:
    for epoch in range(78, 150):
        name = './models/' + version + '/ckpt/ACSP_{0}.pth.tea'.format(epoch)
        if not os.path.exists(name):  # skip epochs whose checkpoint was never saved
            continue
        val(0.36, name, log)