# test.py
#
# Evaluate a trained U-Net on the test set: load the weights from
# unet_model.pth, predict a binary mask for every image in data/test_img,
# write the masks to ./predictions/, and report the average Dice and IoU
# scores against the labels in data/test_lab.
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from unet.unet_model import UNet
# =========================
# Dataset (image and labels)
# =========================
class TestDataset(Dataset):
    """Paired grayscale image / binary label dataset used for evaluation.

    Every entry of ``img_dir`` is treated as one sample; its label is looked
    up under the same file name in ``lab_dir``.
    """

    def __init__(self, img_dir, lab_dir):
        """Remember both directories and index the image file names.

        Args:
            img_dir: Directory holding the input images.
            lab_dir: Directory holding the label masks, named like the images.
        """
        self.img_dir = img_dir
        self.lab_dir = lab_dir
        # Sorted so the sample order does not depend on the file system.
        self.files = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of entries found in ``img_dir``."""
        return len(self.files)

    @staticmethod
    def _read_gray(path):
        """Load ``path`` in grayscale as a float tensor divided by 255.

        The returned tensor has a leading singleton axis added in front of
        the image axes.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque TypeError inside torch.from_numpy.
        if pixels is None:
            raise OSError(f"cv2.imread could not read image file: {path}")
        return torch.from_numpy(pixels).float().unsqueeze(0) / 255.0

    def __getitem__(self, idx):
        """Return ``(image, label, file_name)`` for the sample at ``idx``.

        The label is binarized: values above 0.5 (after the division by 255)
        become 1.0, everything else 0.0.

        Raises:
            OSError: If the image or its label cannot be read.
        """
        fname = self.files[idx]
        img = self._read_gray(os.path.join(self.img_dir, fname))
        lab = self._read_gray(os.path.join(self.lab_dir, fname))
        lab = (lab > 0.5).float()
        return img, lab, fname
# =========================
# Dice score
# =========================
def dice_score(pred, target):
    """Return the smoothed Dice coefficient of two masks.

    Both inputs are flattened and multiplied element-wise to obtain the
    overlap. The result is ``(2 * overlap + eps) / (sum(pred) + sum(target)
    + eps)``.
    """
    # The small constant keeps the ratio defined when both masks are empty
    # (the score is then eps / eps).
    eps = 1e-6

    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)
    overlap = (flat_pred * flat_target).sum()

    numerator = 2. * overlap + eps
    denominator = flat_pred.sum() + flat_target.sum() + eps
    return numerator / denominator
# =========================
# IoU score
# =========================
def iou_score(pred, target):
    """Return the smoothed intersection-over-union of two masks.

    Both inputs are flattened; the overlap is the sum of their element-wise
    product and the union is ``sum(pred) + sum(target) - overlap``.
    """
    # The small constant keeps the ratio defined when the union is empty
    # (the score is then eps / eps).
    eps = 1e-6

    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)

    overlap = (flat_pred * flat_target).sum()
    combined = flat_pred.sum() + flat_target.sum() - overlap

    numerator = overlap + eps
    denominator = combined + eps
    return numerator / denominator
# =========================
# Setup
# =========================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Testing on:", device)

# Build the network (one input channel, one output class) and restore the
# trained weights, mapped onto the selected device while loading.
model = UNet(n_channels=1, n_classes=1)
# NOTE(review): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
state_dict = torch.load("unet_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # switch to evaluation mode before inference

# Test images and labels are served one sample at a time, in file order.
dataset = TestDataset("data/test_img", "data/test_lab")
loader = DataLoader(dataset, shuffle=False, batch_size=1)

# Output directory for the predicted masks; created on first run.
os.makedirs("predictions", exist_ok=True)
# =========================
# Inference loop
# =========================
# Running sums of the per-sample scores and the number of samples scored.
total_dice = 0.0
total_iou = 0.0
count = 0

with torch.no_grad():
    for imgs, labs, fnames in tqdm(loader):
        imgs = imgs.to(device)
        labs = labs.to(device)

        # Model output -> sigmoid -> hard 0/1 mask at the 0.5 threshold.
        preds = (torch.sigmoid(model(imgs)) > 0.5).float()

        # Walk the batch sample by sample so every mask is scored and saved
        # regardless of the loader's batch size (with batch_size=1 this is
        # exactly one iteration per batch).
        for pred, lab, fname in zip(preds, labs, fnames):
            # Calculate metrics
            total_dice += dice_score(pred, lab).item()
            total_iou += iou_score(pred, lab).item()

            # Save mask: take index 0 of the sample's leading axis and turn
            # the 0/1 floats into 0/255 bytes so the image is written as
            # 8-bit data.
            mask = (pred[0].cpu().numpy() * 255).astype(np.uint8)
            save_path = os.path.join("predictions", fname)
            # cv2.imwrite reports failure through its return value; surface
            # it without aborting the run (tqdm.write keeps the bar intact).
            if not cv2.imwrite(save_path, mask):
                tqdm.write(f"Warning: could not write {save_path}")

            count += 1

print("Saved predictions to ./predictions/")
print("\nPerformance Metrics:")
if count:
    print(f"Average Dice Score: {total_dice / count:.4f}")
    print(f"Average IoU Score: {total_iou / count:.4f}")
else:
    # An empty test set would otherwise end in a ZeroDivisionError.
    print("No test samples were found; metrics are undefined.")