-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbinary_sak_check.py
More file actions
32 lines (28 loc) · 1.02 KB
/
binary_sak_check.py
File metadata and controls
32 lines (28 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from Dataset import KTDataset
from model_small import SmallCnn
# Number of output units passed to SmallCnn. A single unit presumably means one
# logit per image: predict() applies a sigmoid and thresholds it for a binary label.
num_classes = 1
base_model = SmallCnn(num_classes)
base_model.load_state_dict(
    torch.load(
        "weights/KT_binary_sak_smalla_29.pth",
        # Remap every stored tensor onto the CPU, whatever device it was saved from.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        map_location="cpu",
    )
)
def predict(model, test_loader, threshold=0.5):
    """Run ``model`` over every batch in ``test_loader`` and binarise the outputs.

    Args:
        model: torch module whose outputs are treated as raw logits
            (a sigmoid is applied here).
        test_loader: iterable of input batches, each passed directly to ``model``.
            NOTE(review): assumes each item is the model input alone (no labels)
            — confirm against KTDataset's ``mode="test"`` output.
        threshold: probability at or above which a row is labelled 1.

    Returns:
        A tuple ``(probs, probs_pro)``:
        - ``probs`` is a NumPy array of sigmoid probabilities, with the batches
          concatenated along the first dimension.
        - ``probs_pro`` is a list holding one 0/1 int per row of ``probs``.
        An empty loader yields an empty array and an empty list.
    """
    # Put the model into evaluation mode once, before iterating.
    model.eval()
    logits = []
    with torch.no_grad():
        for inputs in test_loader:
            # Move outputs to CPU so they can be concatenated and converted to NumPy.
            logits.append(model(inputs).cpu())
    if not logits:
        # torch.cat([]) raises, so return empty results explicitly.
        return torch.empty(0).numpy(), []
    probs = torch.sigmoid(torch.cat(logits)).numpy()
    probs_pro = [1 if p >= threshold else 0 for p in probs]
    return probs, probs_pro
def predict_picture(test_dir='test_img', pattern='*.jpg', model=None):
    """Classify each image under ``test_dir`` separately and print the result.

    Args:
        test_dir: directory searched recursively for images (str or Path).
        pattern: glob pattern passed to ``Path.rglob``. The match is
            case-sensitive on most filesystems, so ``*.JPG`` files need their
            own pattern.
        model: model to evaluate; defaults to the module-level ``base_model``.

    For every matching file, prints its path followed by the ``(probs, probs_pro)``
    pair returned by ``predict``.
    """
    if model is None:
        model = base_model
    # Sort the matches: rglob yields them in filesystem-dependent order.
    test_files = sorted(Path(test_dir).rglob(pattern))
    if not test_files:
        print(f"No files matching {pattern!r} found under {test_dir}")
        return
    for file in test_files:
        # One dataset/loader per file so each result is printed next to its path.
        test_dataset = KTDataset([file], mode="test")
        test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)
        probs, probs_pro = predict(model, test_loader)
        print(file, probs, probs_pro)
# Run batch inference only when executed as a script, so that importing this
# module (e.g. to reuse predict) does not scan test_img and print results.
if __name__ == "__main__":
    predict_picture()