-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·95 lines (70 loc) · 3.96 KB
/
train.py
File metadata and controls
executable file
·95 lines (70 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import numpy as np
import random
import torch
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from src.dataset import GraphMatchingDataset
from src.subproblem_dataset import GraphMatchingSubproblemDataset
from src.model import LinkGNN
from src.utils import run_inference, training_step_link, validation_step_link
from src.utils import normalized_mae, exact_hit_rate
def _train_log_path(ckpt_path):
    """Return the training-log path that sits next to checkpoint ``ckpt_path``.

    Everything after the last ``'.'`` in the path is dropped and
    ``"_train.log"`` is appended, e.g. ``"ckp/model.pt"`` ->
    ``"ckp/model_train.log"``.
    """
    return ckpt_path.rsplit('.', 1)[0] + "_train.log"


def _cpu_state_dict(model):
    """Return ``model.state_dict()`` with every tensor entry moved to the CPU.

    The live model stays on its current device, so the optimizer and the
    training loop are untouched. Entries are replaced in place on the mapping
    ``state_dict()`` returned, so any ``_metadata`` attached to it is kept.
    """
    state = model.state_dict()
    for key in list(state):
        value = state[key]
        # Leave any non-tensor entries exactly as they are.
        if torch.is_tensor(value):
            state[key] = value.cpu()
    return state


def main(args):
    """Train a LinkGNN on graph-matching subproblems and keep the best checkpoint.

    Each epoch runs ``training_step_link`` over the training subproblems,
    ``validation_step_link`` over the validation subproblems, and then
    ``run_inference`` on full validation graph pairs. From the inferred and
    true costs it reports the normalized MAE (nMAE) and the exact hit rate
    (EHR). Whenever nMAE improves and ``args.save_ckp`` is set, the model
    weights are written to ``args.save_ckp``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
            Reads ``data``, ``root``, ``train_pairs``, ``instances_per_pair``,
            ``layers``, ``node_cost``, ``edge_cost``, ``lr``,
            ``weight_decay``, ``epochs``, ``k``, ``device``, ``log`` and
            ``save_ckp``. The whole namespace is also forwarded to
            ``training_step_link`` / ``validation_step_link``.
    """
    train_dataset = GraphMatchingSubproblemDataset(
        name=args.data, root=args.root, num_pairs=args.train_pairs,
        num_instances_per_pair=args.instances_per_pair, split='train')
    # Validation sizes (100 subproblem pairs, 2000 inference pairs) are fixed
    # here rather than exposed on the command line.
    val_dataset = GraphMatchingSubproblemDataset(
        name=args.data, root=args.root, num_pairs=100,
        num_instances_per_pair=args.instances_per_pair, split='val')
    val_dataset_inf = GraphMatchingDataset(
        name=args.data, root=args.root, num_pairs=2000, split='val')

    # Feature widths are read off the first training sample.
    num_node_labels = train_dataset[0].x.shape[1]
    num_edge_labels = train_dataset[0].edge_attr.shape[1]

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, pin_memory=True)

    # 128 is the hidden width passed to LinkGNN.
    model = LinkGNN(num_node_labels, num_edge_labels, 128, args.layers,
                    args.node_cost, args.edge_cost)
    model = model.to(args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Per-epoch metrics are appended only when logging is on and a checkpoint
    # path exists to derive the log file name from.
    log_path = _train_log_path(args.save_ckp) if (args.log and args.save_ckp) else None

    best_nmae = float('inf')
    for _ in tqdm(range(args.epochs), ncols=64):
        # Re-enter training mode every epoch: the validation / inference
        # helpers below may switch the model to eval mode and are not
        # guaranteed to switch it back.
        model.train()
        epoch_loss, epoch_acc = training_step_link(model, train_loader, optimizer, args)
        epoch_val_loss, epoch_val_acc = validation_step_link(model, val_loader, args)
        costs, true_costs = run_inference(model, val_dataset_inf, k=args.k, batch_size=64, disable_tqdm=True)
        nmae = normalized_mae(costs, true_costs)
        ehr = exact_hit_rate(costs, true_costs)
        print(f'train-loss: {epoch_loss:.5f}, nMAE: {nmae:.5f}, EHR: {ehr:.5f}')

        if log_path is not None:
            with open(log_path, "a") as f:
                f.write(f'{epoch_loss:.6f} {epoch_acc:.5f} {epoch_val_loss:.6f} {epoch_val_acc:.5f} {nmae:.5f} {ehr:.5f}\n')

        if args.save_ckp is not None and nmae < best_nmae:
            best_nmae = nmae
            # Save CPU copies of the weights instead of moving the whole
            # model to the CPU and back, which reallocates every parameter
            # and gradient on the device each time a new best is found.
            torch.save(_cpu_state_dict(model), args.save_ckp)
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), kept in
    # the original registration order.
    cli_options = (
        ('--seed', dict(type=int, default=0)),
        ('--root', dict(type=str, default='data/')),
        ('--data', dict(type=str, default=None)),
        ('--epochs', dict(type=int, default=30)),
        ('--layers', dict(type=int, default=5)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=0.0)),
        ('--max_train_steps', dict(type=float, default=1.0)),
        ('--k', dict(type=int, default=32)),
        ('--train_pairs', dict(type=int, default=None)),
        ('--instances_per_pair', dict(type=int, default=40)),
        ('--node_cost', dict(type=float, default=1.0)),
        ('--edge_cost', dict(type=float, default=1.0)),
        ('--save_ckp', dict(type=str, default=None)),
        ('--log', dict(action='store_true')),
        ('--nocuda', dict(action='store_true')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Use the GPU unless it is unavailable or explicitly disabled.
    use_cuda = torch.cuda.is_available() and not args.nocuda
    args.device = torch.device("cuda" if use_cuda else "cpu")
    print(args)

    # Start a fresh training log whose first line records the full configuration.
    if args.log and args.save_ckp:
        header_path = args.save_ckp.rsplit('.', 1)[0] + "_train.log"
        with open(header_path, "w") as log_file:
            log_file.write(str(args) + '\n')

    # Seed every RNG the run may touch.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.set_printoptions(linewidth=200, edgeitems=20)
    main(args)