-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvalidator.py
More file actions
26 lines (22 loc) · 816 Bytes
/
validator.py
File metadata and controls
26 lines (22 loc) · 816 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
class Validator:
    """Compute the mean cross-entropy loss of a model over a data loader.

    The model is switched to evaluation mode and every batch yielded by
    ``dataload`` is pushed through it without gradient tracking. The
    reported loss is the arithmetic mean of the per-batch losses.
    """

    def __init__(self, model, dataload, epoch, device, batch_size):
        """Store the evaluation configuration and build the loss function.

        Args:
            model: Module to evaluate; must support ``eval()`` and be
                callable on an input batch.
            dataload: Iterable yielding ``(input, label)`` pairs.
            epoch: Epoch identifier forwarded to the evaluation pass.
            device: Device the batches and the loss module are moved to.
            batch_size: Stored on the instance; not used by this class.
        """
        self.model = model
        self.dataload = dataload
        self.epoch = epoch
        self.device = device
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def __epoch(self, epoch):
        """Run one pass over ``self.dataload`` and average the batch losses.

        Args:
            epoch: Accepted for interface compatibility; not used in the
                computation.

        Returns:
            A dict ``{'val_loss': mean_loss}`` where ``mean_loss`` is the
            sum of per-batch losses divided by the number of batches.

        Raises:
            ValueError: If ``self.dataload`` yields no batches, since the
                mean loss is undefined in that case.
        """
        self.model.eval()
        loss_sum = 0.0
        num_batches = 0
        # Validation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for inp, label in self.dataload:
                inp = inp.float().to(self.device)
                label = label.long().to(self.device)
                # Call the module itself (not .forward) so hooks still run.
                out = self.model(inp)
                loss_sum += self.criterion(out, label).item()
                num_batches += 1
        if num_batches == 0:
            raise ValueError(
                "cannot compute validation loss: data loader yielded no batches"
            )
        return {'val_loss': loss_sum / num_batches}

    def eval(self):
        """Evaluate the model for the configured epoch.

        Returns:
            The ``{'val_loss': mean_loss}`` dict produced by the
            evaluation pass.
        """
        return self.__epoch(self.epoch)