-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathtest_foliage.py
More file actions
39 lines (32 loc) · 1016 Bytes
/
test_foliage.py
File metadata and controls
39 lines (32 loc) · 1016 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import sys
import os
import h5py
# Absolute path of the directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the sibling 'models' directory on the import path so the import below
# can be resolved.
# NOTE(review): presumably PreProcess.py lives in models/ — confirm.
sys.path.append(os.path.join(BASE_DIR, 'models'))
from PreProcess import FoliageFilter
# Restore the trained foliage filter from its checkpoint and print the
# hyperparameters stored with it (done before wrapping, while the object
# still exposes `hparams` directly).
foliage_filter = FoliageFilter.load_from_checkpoint(
    "./fckpt/epoch=28-val_loss=0.17-val_acc=0.951.ckpt"
)
print(foliage_filter.hparams)

# Wrap in DataParallel, move the parameters to CUDA, and switch to
# evaluation mode for inference.
model = torch.nn.DataParallel(foliage_filter).cuda()
model.eval()
# Read the labeled test split fully into memory. The context manager closes
# the HDF5 file even when a dataset is missing or a read fails, which the
# previous manual open/close pair did not guarantee.
with h5py.File('data/Foliage_Segmentation/tree_labeled_test.hdf5', 'r') as f:
    ds = f['points'][:]   # point data, iterated one entry at a time below
    fns = f['names'][:]   # names, copied through unchanged to the output file
    # The 'centroid' and 'scale' datasets are currently not loaded:
    # centroid = f['centroid'][:]
    # scales = f['scale'][:]
# Run the model over every entry of `ds`, one entry per forward pass, and
# collect the predicted class indices as NumPy arrays in `preds_if`.
preds_if = []
# Inference only: disable autograd so no computation graph is built or kept
# alive for each forward pass, which reduces GPU memory use.
with torch.no_grad():
    for sample in ds:
        # Convert to float32 and add a leading batch dimension of size 1.
        batch = torch.unsqueeze(torch.from_numpy(sample).float(), dim=0)
        logits_if = model(batch.cuda())
        # Index of the highest score along dim 1.
        # NOTE(review): assumes dim 1 is the class axis — confirm against
        # FoliageFilter's output layout.
        pred_if = torch.argmax(logits_if, dim=1)
        # Drop all singleton dimensions (including the batch dimension)
        # before moving the result back to host memory.
        preds_if.append(torch.squeeze(pred_if).cpu().numpy())
# Datasets to persist, in write order: the input points, the per-entry
# predictions, and the names carried over from the input file.
outputs = {
    'points': ds,
    'pred_isfoliage': preds_if,
    'names': fns,
    # Not written at the moment:
    # 'centroid': centroid,
    # 'scale': scales,
}

# Write everything into a fresh HDF5 file (mode 'w' truncates any existing one).
with h5py.File('./foliage_seg.hdf5', 'w') as f:
    for key, payload in outputs.items():
        f[key] = payload