-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_utils.py
More file actions
71 lines (60 loc) · 2.22 KB
/
eval_utils.py
File metadata and controls
71 lines (60 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from config import device
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def train(model, data_loader, val_loader, model_path, epochs=100):
    """Train ``model`` with BCE loss and restore the best validation checkpoint.

    Each epoch runs one pass of Adam (lr=0.005, weight_decay=5e-4) over
    ``data_loader``, then evaluates on ``val_loader`` via ``val``.
    Whenever the validation loss improves, the model's ``state_dict`` is
    written to ``model_path``. After the last epoch, the best saved weights
    are loaded back into ``model`` in place, so the caller's object ends up
    holding the best checkpoint rather than the last-epoch weights.

    Args:
        model: Module called as ``model(x, edge_index, batch)``. It is
            assumed to output probabilities, since ``BCELoss`` requires
            inputs in [0, 1] — TODO confirm the last layer is a sigmoid.
        data_loader: Training batches. Presumably PyTorch Geometric-style
            batches exposing ``x``, ``edge_index``, ``batch`` and ``y`` —
            verify against caller.
        val_loader: Validation batches of the same form.
        model_path: File path where the best ``state_dict`` is saved.
        epochs: Number of passes over ``data_loader``.

    Returns:
        The lowest mean validation loss observed (``inf`` if no epoch ran
        or no validation loss was ever finite-and-smaller).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    loss_fn = torch.nn.BCELoss()
    best_val_loss = float("inf")
    # Tracks whether *this* call wrote a checkpoint, so we never reload a
    # stale file from a previous run (or a missing one).
    saved = False
    for epoch in range(epochs):
        # `val` switches the model to eval mode, so re-enable train mode
        # at the start of every epoch.
        model.train()
        epoch_loss = 0.0
        for data in data_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = loss_fn(out, data.y.float())
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(data_loader)
        val_loss = val(model, val_loader)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)
            saved = True
        print(
            f"Epoch: {epoch}, Train Loss: {epoch_loss}, "
            f"Best: {best_val_loss}, Val Loss: {val_loss}",
            end="\r",
        )
    # Terminate the carriage-return status line so later output is not
    # printed over it.
    print()
    if saved:
        # Load the weights into the caller's model. Rebinding a local name
        # (as `model = torch.load(...)` did) has no effect outside this
        # function.
        model.load_state_dict(torch.load(model_path, map_location=device))
    return best_val_loss
def val(model, data_loader):
    """Compute the average BCE loss of ``model`` over ``data_loader``.

    The model is switched to eval mode (and left there), and each forward
    pass runs under ``torch.no_grad()``. The returned value is the mean of
    the per-batch losses. ``BCELoss`` defaults to mean reduction, so every
    batch contributes equally regardless of its size.
    """
    model.eval()
    criterion = torch.nn.BCELoss()
    batch_losses = []
    for graph_batch in data_loader:
        graph_batch = graph_batch.to(device)
        with torch.no_grad():
            prediction = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
            batch_losses.append(criterion(prediction, graph_batch.y.float()).item())
    # sum() starts at 0 and adds left to right, matching a running total.
    return sum(batch_losses) / len(data_loader)
def test(model, data_loader, *, threshold=0.5):
    """Evaluate ``model`` on ``data_loader`` and compute binary metrics.

    Args:
        model: Module called as ``model(x, edge_index, batch)``. It is
            assumed to output probabilities (sigmoid at the last layer).
        data_loader: Batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``. Presumably PyTorch Geometric-style — verify against
            caller.
        threshold: Outputs strictly greater than this value are predicted
            positive. Defaults to 0.5, the previously hard-coded value.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, correct_mask)``. The first
        four come from scikit-learn. ``correct_mask`` is the elementwise
        NumPy comparison ``y_true == y_pred``, i.e. True where the
        prediction matched the label.
    """
    model.eval()
    y_true = []
    y_pred = []
    # Inference only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            # Move results to CPU per batch so the accumulated lists do not
            # hold device memory for the whole dataset.
            y_true.append(data.y.cpu())
            # Threshold because of the sigmoid activation at the last layer.
            y_pred.append((out > threshold).cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # Mask marking which samples were predicted correctly.
    correct_mask = y_true == y_pred
    return accuracy, precision, recall, f1, correct_mask