-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_StreamMOS.py
More file actions
140 lines (118 loc) · 5.32 KB
/
test_StreamMOS.py
File metadata and controls
140 lines (118 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import pdb
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import datasets
from utils.metric import MultiClassMetric
from models import *
# from visualization import draw_single_lidar_with_label
import tqdm
import importlib
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter as Logger
# Enable cuDNN autotuner: picks the fastest conv algorithms for the (fixed)
# input shapes seen at inference time.
cudnn.benchmark = True
cudnn.enabled = True
def map(label, mapdict):
    """Remap label values through a lookup table built from *mapdict*.

    Used both to map raw labels to training ids and back again, depending
    on the dictionary passed in.

    Args:
        label: scalar or numpy array of integer label values.
        mapdict: dict mapping source key -> target value (int or list of ints).

    Returns:
        numpy array of remapped labels (2-D when the table entries are lists).

    NOTE: keeps the builtin-shadowing name ``map`` for caller compatibility.
    """
    # Find the largest key so the LUT can be sized to cover every entry.
    top_key = 0
    for key, value in mapdict.items():
        width = len(value) if isinstance(value, list) else 1
        if key > top_key:
            top_key = key
    # +100 slack so unknown labels slightly above the max key still index safely.
    # The width of the *last* entry decides whether the LUT is 2-D.
    if width > 1:
        lut = np.zeros((top_key + 100, width), dtype=np.int32)
    else:
        lut = np.zeros((top_key + 100), dtype=np.int32)
    for key, value in mapdict.items():
        try:
            lut[key] = value
        except IndexError:
            print("Wrong key ", key)
    # Fancy indexing performs the whole remap in one shot.
    return lut[label]
def load_data_to_gpu(batch_dict):
    """Move every tensor-like entry of *batch_dict* onto the GPU, in place.

    Entries under the 2-D box keys are lists of per-frame dicts holding numpy
    'boxes'/'labels' arrays; those are converted to cuda tensors element-wise.
    Any other list-valued entry (e.g. string ids, mask lists) is left on the
    host; everything else is assumed to be a tensor and cast to float on cuda.
    """
    box_keys = ('box_2d_label', 'box_2d_label_raw')
    for key, value in batch_dict.items():
        if key in box_keys:
            # Per-frame dicts: numpy boxes -> float cuda, labels -> long cuda.
            for frame in value:
                frame['boxes'] = torch.from_numpy(frame['boxes']).float().cuda()
                frame['labels'] = torch.from_numpy(frame['labels']).long().cuda()
        if isinstance(value, list):
            # Lists stay host-side (their contents were handled above if needed).
            continue
        batch_dict[key] = value.float().cuda()
def val(epoch, model, val_loader, category_list, save_path, tb_logger, learning_map_inv, rank=0):
    """Run streaming inference over the test loader and dump per-scan labels.

    For each batch the model is queried with the running query_embed_store
    (carried across frames for the streaming architecture), class logits are
    reduced to per-point hard labels, padded points are stripped, invalid
    points are zero-filled, labels are mapped back to raw dataset ids via
    ``learning_map_inv``, and the result is written as a SemanticKITTI-style
    .label file under ``test/<seq_id>/predictions/``.

    Args:
        epoch: checkpoint epoch being evaluated (not used in the body).
        model: network exposing ``infer(batch_dict, step, query_embed_store)``.
        val_loader: DataLoader yielding batch dicts (batch_size 1 expected).
        category_list: class names; used to construct the metric helper.
        save_path: experiment directory for the per-rank record file.
        tb_logger: tensorboard logger (not used in the body).
        learning_map_inv: dict mapping training ids back to raw label ids.
        rank: distributed rank; keeps per-rank record files separate.
    """
    criterion_cate = MultiClassMetric(category_list)
    model.eval()
    query_embed_store = None
    test_save_path = 'test'
    # FIX: the record file was opened and never closed — a file-descriptor
    # leak that grows with every evaluated epoch. Manage it with `with`.
    with open(os.path.join(save_path, 'record_{}.txt'.format(rank)), 'a') as f:
        with torch.no_grad():
            for i, batch_dict in tqdm.tqdm(enumerate(val_loader)):
                load_data_to_gpu(batch_dict)
                pred_cls, pred_res_cls_0, pred_res_cls_1, pred_res_cls_2, query_embed_store = model.infer(batch_dict, i, query_embed_store)
                pred_cls = F.softmax(pred_cls, dim=1)
                # Average over the batch/view dim, then flatten to (points, classes).
                pred_cls = pred_cls.mean(dim=0).permute(2, 1, 0).squeeze(0).contiguous()
                # ------- reshape back to the raw point-cloud layout -------
                valid_mask = batch_dict['valid_mask_list'][0].reshape(-1)
                new_pred_cls = np.zeros((batch_dict['valid_mask_list'][0].shape[1]))
                _, pred_cls = torch.max(pred_cls, dim=1)
                # Drop the padding points appended by the collate step.
                pred_cls = pred_cls[:pred_cls.shape[0] - batch_dict['pad_length_list'][0][0]]
                new_pred_cls[valid_mask] = pred_cls.cpu().numpy()
                new_pred_cls = new_pred_cls.astype('uint32')
                # ------- save one .label file per scan -------
                item_test_save_path = os.path.join(test_save_path, batch_dict['seq_id'][0], 'predictions')
                # FIX: exist_ok avoids the check-then-create race when several
                # ranks write into the same sequence directory.
                os.makedirs(item_test_save_path, exist_ok=True)
                pred_map = map(new_pred_cls, learning_map_inv)
                pred_map.tofile(os.path.join(item_test_save_path, batch_dict['file_id'][0] + '.label'))
def main(args, config):
    """Distributed test-set inference: each rank evaluates distinct checkpoints.

    Args:
        args: parsed CLI namespace (config, local_rank, tag, start_epoch, end_epoch).
        config: imported config module exposing ``get_config()``.
    """
    # parsing cfg
    pGen, pDataset, pModel, pOpt = config.get_config()
    prefix = pGen.name
    save_path = os.path.join("experiments", prefix, args.tag)
    model_prefix = os.path.join(save_path, "checkpoint")
    tb_logger = Logger(save_path + "/test_tb")
    # reset dist
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    # define dataloader
    # NOTE(review): eval() on config-derived strings is fine for trusted
    # configs but unsafe if the config path can come from untrusted input.
    val_dataset = eval('datasets.{}.DataloadTest'.format(pDataset.Test.data_src))(pDataset.Test)
    # NOTE(review): num_workers comes from pDataset.Val while the dataset is
    # built from pDataset.Test — confirm this mix is intentional.
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=pDataset.Val.num_workers,
                            pin_memory=True)
    # define model
    model = eval(pModel.prefix)(pModel)
    model.cuda()
    model.eval()
    # Stride the epoch range by world_size so each rank loads and evaluates a
    # different checkpoint (epoch + rank) in parallel; the guard skips ranks
    # whose offset epoch falls past end_epoch.
    for epoch in range(args.start_epoch, args.end_epoch + 1, world_size):
        if (epoch + rank) < (args.end_epoch + 1):
            pretrain_model = os.path.join(model_prefix, '{}-model.pth'.format(epoch + rank))
            model.load_state_dict(torch.load(pretrain_model, map_location='cpu'))
            val(epoch + rank, model, val_loader, pGen.category_list, save_path, tb_logger, pDataset.Test.learning_map_inv, rank)
if __name__ == '__main__':
    # CLI entry point: parse arguments, import the config module, run inference.
    arg_parser = argparse.ArgumentParser(description='lidar segmentation')
    arg_parser.add_argument('--config', help='config file path', type=str)
    arg_parser.add_argument('--local_rank', type=int, default=0)
    arg_parser.add_argument('--tag', help='config file path', type=str, default='base')
    arg_parser.add_argument('--start_epoch', type=int, default=40)
    arg_parser.add_argument('--end_epoch', type=int, default=41)
    cli_args = arg_parser.parse_args()
    # Turn a path like "configs/foo.py" into a dotted module name "configs.foo".
    module_name = cli_args.config.replace('.py', '').replace('/', '.')
    main(cli_args, importlib.import_module(module_name))