-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathretrieval.py
More file actions
99 lines (73 loc) · 3.79 KB
/
retrieval.py
File metadata and controls
99 lines (73 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import sys
import torch
import numpy as np
import argparse
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'classification'))
from tqdm import tqdm
from classification.models.pointnet_cls import PointNet
from utils.data_utils.ModelNetDataLoader import ModelNetTestDataLoader
from utils.analyze_precision_recall import calc_macro_mean_average_precision
# Usage example (run on GPU 0): CUDA_VISIBLE_DEVICES=0 python retrieval.py
def parse_args(argv=None):
    """Parse the command-line options of the retrieval evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers that invoke ``parse_args()`` behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``num_class``, ``num_point``,
        ``num_delaunay``, ``use_normals``, ``use_delaunay``,
        ``use_uniform_sample`` and ``spl_batch_size``.
    """
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--num_class', default=40, type=int, choices=[10, 40],
                        help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--num_delaunay', type=int, default=32, help='Delaunay')
    parser.add_argument('--use_normals', action='store_true', default=False,
                        help='use normals')
    # NOTE: this is a store_false flag, so delaunay neighbors are ON by default
    # and passing --use_delaunay turns them OFF. The flag name is kept for
    # backward compatibility; the help text now states what it really does.
    parser.add_argument('--use_delaunay', action='store_false', default=True,
                        help='disable delaunay neighbors (enabled by default)')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False,
                        help='use uniform sampling')
    parser.add_argument('--spl_batch_size', type=int, default=4,
                        help='batch size of the test data loader')
    return parser.parse_args(argv)
def _load_classifier(args, device, checkpoint_path):
    """Build the PointNet classifier on *device* and restore its weights from *checkpoint_path*."""
    model = PointNet(args).to(device)
    try:
        # map_location remaps the stored tensors onto the selected device, so a
        # checkpoint saved on a GPU still loads on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        print('Error: No existing classification model')
        # Non-zero status so shell pipelines can detect the failure.
        sys.exit(1)
    # Any other problem (missing 'model_state_dict' key, mismatching layer
    # shapes, ...) propagates with its real traceback instead of being
    # reported as a missing model.
    model.load_state_dict(checkpoint['model_state_dict'])
    print('Load pretrained classification model')
    return model


def _evaluate(model, loader, device, num_class):
    """Classify every batch of *loader*; return (instance_acc, class_acc, ret_vecs, labels)."""
    total_correct = 0
    total_seen = 0
    # Column 0: correct predictions per class, column 1: samples seen per class.
    class_stats = np.zeros((num_class, 2), dtype=float)
    ret_vec_ls, label_ls = [], []

    model.eval()
    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            points, target = points.float().to(device), target.long().to(device)

            # With retrieval=True the model returns the class scores together
            # with the vector used for retrieval.
            pred, ret_vec = model(points.transpose(1, 2), retrieval=True)
            pred_choice = pred.max(1)[1]
            hits = pred_choice.eq(target)

            total_correct += hits.sum().item()
            total_seen += target.size(0)

            for cat in np.unique(target.cpu().numpy()):
                mask = target == cat
                class_stats[cat, 0] += hits[mask].sum().item()
                class_stats[cat, 1] += mask.sum().item()

            ret_vec_ls.append(ret_vec.cpu().numpy())
            # Labels are stored as a column so they stack row-wise with vstack.
            label_ls.append(np.expand_dims(target.cpu().numpy(), axis=-1))

    if total_seen == 0:
        raise ValueError('the test data loader produced no samples')

    instance_acc = total_correct / float(total_seen)
    # Average only over classes that occur in the test set: a class with zero
    # samples would otherwise yield 0/0 = NaN and poison the mean.
    seen = class_stats[:, 1] > 0
    class_acc = float(np.mean(class_stats[seen, 0] / class_stats[seen, 1]))

    return instance_acc, class_acc, np.vstack(ret_vec_ls), np.vstack(label_ls)


def main():
    """Evaluate a pretrained PointNet classifier and report retrieval mAP.

    Steps: parse the command-line options, build the test data loader from the
    hard-coded dataset and sampled-points directories, restore the classifier
    from the hard-coded checkpoint, print instance and class accuracy, then
    print the macro mean average precision computed from the collected
    retrieval vectors and labels.

    Raises:
        SystemExit: with status 1 when the checkpoint file does not exist.
        ValueError: when the test data loader yields no samples.
    """
    args = parse_args()

    # Fall back to the CPU when no CUDA device is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('Load dataset ...')
    # ------------------ ModelNet -----------------
    data_path = 'data/modelnet40_normal_resampled/'
    sample_path = 'log/sampling_2022-11-11-11-20-24/test_sampling/best_points/'
    test_dataset = ModelNetTestDataLoader(root=data_path, sampled_path=sample_path, args=args)
    # shuffle=False keeps the sample order fixed from run to run.
    testDataLoader_spl = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.spl_batch_size, shuffle=False, num_workers=4)

    # -------------------- Task network --------------------
    model_cls = _load_classifier(
        args, device, './classification/log/pointnet_2022-10-24-20-25_seed_390/best_model.pth')

    instance_acc, class_acc, ret_vecs, labels = _evaluate(
        model_cls, testDataLoader_spl, device, args.num_class)
    print('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))

    res_macro = calc_macro_mean_average_precision(ret_vecs, labels)
    print('Retrieval mAP: %f' % res_macro)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()