-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
66 lines (56 loc) · 2.21 KB
/
model.py
File metadata and controls
66 lines (56 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import cv2
import torch
from torch import nn
import torchvision.models as models
import transform
from network import *
# Inference configuration read by the `model` class below.

# Network constructor; called in model.load() as net(num_classes=num_classes).
# Presumably supplied by the star import from `network` - confirm there.
net = SResNet
# Passed to the network constructor, and also selects how predict() turns
# the raw score into a class (1 -> sigmoid + threshold, 2 -> arg-max).
num_classes = 2
# Checkpoint file name only; model.load() joins it with dir_path.
ckpt_path = "Net.pth"
# Forwarded as `method=` to transform.transform_method() in model.predict().
transform_method_origin = 1
# Cut-off applied to the sigmoid probability; only used when num_classes == 1.
threshold = 0.5
class model:
    """Inference wrapper around the network selected by the module-level ``net``.

    Typical use: construct an instance, call :meth:`load` once with the
    submission directory, then call :meth:`predict` for each image.
    """

    def __init__(self, device=torch.device("cpu")):
        """
        store the checkpoint file name and the device used for inference.
        :param device: torch device the network and the input tensors are moved to.
        """
        self.checkpoint = ckpt_path
        self.device = device
        # Filled in by load(). Kept as None until then so that predict() can
        # report a clear error instead of an AttributeError.
        self.model = None

    def load(self, dir_path):
        """
        load the model and weights.
        dir_path is a string for internal use only - do not remove it.
        all other paths should only contain the file name, these paths must be
        concatenated with dir_path, for example: os.path.join(dir_path, filename).
        make sure these files are in the same directory as the model.py file.
        :param dir_path: path to the submission directory (for internal use only).
        :return:
        """
        self.model = net(num_classes=num_classes)
        # join paths
        checkpoint_path = os.path.join(dir_path, self.checkpoint)
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # inference mode for layers such as dropout / batch-norm
        self.model.eval()

    def predict(self, input_image):
        """
        perform the prediction given an image.
        input_image is a ndarray read using cv2.imread(path_to_image, 1).
        note that the order of the three channels of the input_image read by cv2 is BGR.
        :param input_image: the input image to the model.
        :return: an int value indicating the class for the input image
        :raises RuntimeError: if load() has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "model weights are not loaded; call load(dir_path) before predict()"
            )

        # image transform
        image = transform.transform_method(method=transform_method_origin)(input_image)
        # image dimension expansion (do not change)
        image = image.unsqueeze(0)  # (3, x, x) -> (1, 3, x, x)
        # image to device
        image = image.to(self.device, torch.float)

        with torch.no_grad():
            score = self.model(image)

        if num_classes == 1:
            # single-logit head: sigmoid probability compared with the threshold
            probability = torch.sigmoid(score).item()
            return int(probability >= threshold)

        # multi-class head: index of the largest score along the class axis.
        # This covers num_classes == 2 and any larger class count, so the
        # result is always defined.
        _, pred_class = torch.max(score, dim=1)
        return int(pred_class.item())