-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels_graph.py
More file actions
89 lines (76 loc) · 2.82 KB
/
models_graph.py
File metadata and controls
89 lines (76 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, GATConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import DataLoader, Data
from config import args
def load_model(name="GCN", in_dim=14, hidden_dim=64, out_dim=1):
    """Build and return a graph model by name.

    Args:
        name: one of "GCN", "GAT", or "GIN".
        in_dim: node-feature dimension fed to the first layer.
        hidden_dim: width of the hidden graph-conv layers.
        out_dim: output dimension of the final linear head.

    Raises:
        NotImplementedError: if `name` is not a known model.
    """
    registry = {"GCN": GCN, "GAT": GAT, "GIN": GIN}
    if name not in registry:
        raise NotImplementedError(f"Model {name} is not implemented.")
    return registry[name](in_dim, hidden_dim, out_dim)
class GCN(torch.nn.Module):
    """Two-layer GCN graph classifier.

    Node features are passed through two GCNConv layers (ReLU + dropout
    after each), mean-pooled per graph, projected to `out_dim`, and
    squashed with a sigmoid — suited to binary graph classification
    when out_dim == 1.
    """

    def __init__(self, in_dim=None, hidden_dim=None, out_dim=None):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch=None):
        """Return per-graph sigmoid scores of shape (num_graphs,) when out_dim == 1.

        Args:
            x: node feature matrix, (num_nodes, in_dim).
            edge_index: COO edge indices, (2, num_edges).
            batch: graph-membership vector per node; None treats all
                nodes as one graph.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)  # p defaults to 0.5; active only in train mode
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = global_mean_pool(x, batch)  # (num_graphs, hidden_dim)
        x = self.lin(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical output).
        x = torch.sigmoid(x)
        x = x.squeeze(1)  # drop the out_dim axis when it is size 1
        return x
class GAT(torch.nn.Module):
    """Two-layer GAT graph classifier.

    Mirrors GCN but uses attention-based GATConv layers: two conv layers
    (ReLU + dropout after each), per-graph mean pooling, a linear head,
    and a sigmoid output — suited to binary graph classification when
    out_dim == 1.
    """

    def __init__(self, in_dim=None, hidden_dim=None, out_dim=None):
        super().__init__()
        self.conv1 = GATConv(in_dim, hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch=None):
        """Return per-graph sigmoid scores of shape (num_graphs,) when out_dim == 1.

        Args:
            x: node feature matrix, (num_nodes, in_dim).
            edge_index: COO edge indices, (2, num_edges).
            batch: graph-membership vector per node; None treats all
                nodes as one graph.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)  # p defaults to 0.5; active only in train mode
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = global_mean_pool(x, batch)  # (num_graphs, hidden_dim)
        x = self.lin(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical output).
        x = torch.sigmoid(x)
        x = x.squeeze(1)  # drop the out_dim axis when it is size 1
        return x
class GIN(torch.nn.Module):
    """Three-layer GIN graph classifier.

    Each GINConv wraps a single Linear layer as its update network.
    Three conv+ReLU stages, per-graph mean pooling, a linear head, and
    a sigmoid output — suited to binary graph classification when
    out_dim == 1. Unlike GCN/GAT above, no dropout is applied.
    """

    def __init__(self, in_dim=None, hidden_dim=None, out_dim=None):
        super().__init__()  # modern form, consistent with GCN/GAT
        self.mlp1 = torch.nn.Linear(in_dim, hidden_dim)
        self.conv1 = GINConv(self.mlp1)
        self.mlp2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.conv2 = GINConv(self.mlp2)
        self.mlp3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.conv3 = GINConv(self.mlp3)
        self.lin = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, batch=None):
        """Return per-graph sigmoid scores of shape (num_graphs,) when out_dim == 1.

        Args:
            x: node feature matrix, (num_nodes, in_dim).
            edge_index: COO edge indices, (2, num_edges).
            batch: graph-membership vector per node; None treats all
                nodes as one graph.
        """
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = x.relu()
        x = global_mean_pool(x, batch)  # (num_graphs, hidden_dim)
        # x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical output).
        x = torch.sigmoid(x)
        x = x.squeeze(1)  # drop the out_dim axis when it is size 1
        return x