-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
56 lines (44 loc) · 2.01 KB
/
test.py
File metadata and controls
56 lines (44 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import jaccard_score
from utils_data import Dataset_Seg
from utils_model import UNet
from pathlib import Path
from utils import set_seed
def compute_iou_on_test(model, test_loader, device, threshold=0.5):
    """Evaluate ``model`` on ``test_loader`` and return the Jaccard index (IoU).

    Predicted and ground-truth masks from every batch are flattened into one
    label vector each and scored with ``jaccard_score``. The score is also
    printed with four decimals.

    Args:
        model: Callable network exposing ``eval()`` and an ``n_classes``
            attribute. It is switched to eval mode and left in that mode.
        test_loader: Iterable yielding ``(images, true_masks)`` pairs.
        device: Device the batches are moved to before the forward pass.
        threshold: Cut-off applied to the sigmoid output when
            ``model.n_classes == 1``; unused in the multi-class branch.

    Returns:
        The value returned by ``jaccard_score``, computed with
        ``average='binary'`` when ``model.n_classes == 1`` and
        ``average='macro'`` otherwise.
    """
    model.eval()
    single_class = model.n_classes == 1

    # One flat CPU tensor per batch. Extending a Python list pixel by pixel
    # would box every label into its own Python object, which is slow and
    # memory-hungry for full-resolution masks.
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for images, true_masks in tqdm(test_loader, desc="Evaluating IoU", unit="img"):
            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)
            output = model(images)

            if single_class:
                # Single logit map: drop the channel axis, then threshold.
                probs = torch.sigmoid(output.squeeze(1))
                pred_masks = (probs > threshold).long()
            else:
                # One score per class: pick the most probable class per pixel.
                probs = torch.softmax(output, dim=1)
                pred_masks = probs.argmax(dim=1)

            pred_chunks.append(pred_masks.flatten().cpu())
            target_chunks.append(true_masks.flatten().cpu())

    if pred_chunks:
        preds = torch.cat(pred_chunks).numpy()
        targets = torch.cat(target_chunks).numpy()
    else:
        # torch.cat rejects an empty list; pass empty label vectors through
        # so the metric itself decides how an empty test set is treated.
        preds, targets = [], []

    iou = jaccard_score(targets, preds, average='binary' if single_class else 'macro')
    print(f"IoU: {iou:.4f}")
    return iou
if __name__ == '__main__':
    # Script entry point: load the best checkpoint and report IoU on the
    # 'test' split found under path_to_data.
    path_to_data = '/content/drive/MyDrive/archive'
    set_seed()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model.to(device=device)

    # map_location is a torch.load argument; load_state_dict does not accept
    # it and raises TypeError. Passing it to torch.load also remaps the
    # checkpoint tensors onto `device` (e.g. a CUDA-saved file on a CPU host).
    state_dict = torch.load(Path(path_to_data) / 'checkpoint_best.pth', map_location=device)
    model.load_state_dict(state_dict)

    test_data = Dataset_Seg(path_to_data, 'test')
    loader_args = dict(batch_size=1, pin_memory=True)
    test_loader = DataLoader(test_data, shuffle=False, drop_last=False, **loader_args)

    iou_score = compute_iou_on_test(model, test_loader, device)
    print("Final IoU on test set:", iou_score)