-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
113 lines (97 loc) · 3.9 KB
/
model.py
File metadata and controls
113 lines (97 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.nn.models import GAT
from torch_geometric.nn import EdgeConv
from torch_geometric.nn import AttentionalAggregation
# from sklearn.base import BaseEstimator, ClassifierMixin
# from sklearn.model_selection import GridSearchCV
# from sklearn.metrics import make_scorer, accuracy_score
from torch_geometric.utils import to_undirected
from torch_cluster import knn_graph
from torch_geometric.nn import GraphNorm
class GATWithPooling(nn.Module):
    """Graph attention network with an EdgeConv front-end and attentional pooling.

    Pipeline: raw node features -> point-wise MLP embedding -> dynamic kNN
    graph (built in embedding space) -> EdgeConv -> three stacked GAT blocks
    with dense skip concatenations -> attentional graph-level pooling.

    Returns both a graph-level prediction (``fc``) and a per-node ("track")
    prediction (``fc_trk``) taken from the intermediate EdgeConv features.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, out_channels, dropout=0.5, act='elu',aggr='sum', act_first=False, norm=None, v2=False, jk=None):
        super().__init__()
        # Neighbourhood size for the dynamic kNN graph. NOTE(review): this
        # deliberately reuses num_layers as k — confirm that is intended.
        self.k = num_layers
        wide = 2 * hidden_channels
        narrow = (3 * hidden_channels) // 2
        # Point-wise MLP lifting raw node features into the hidden space.
        self.inputnet = nn.Sequential(
            nn.Linear(in_channels, wide),
            nn.ELU(),
            nn.Linear(wide, wide),
            nn.ELU(),
            nn.Linear(wide, hidden_channels),
            nn.ELU(),
        )
        edge_mlp1 = nn.Sequential(
            nn.Linear(wide, narrow),
            nn.ELU(),
            nn.Linear(narrow, hidden_channels),
            nn.ELU(),
        )
        edge_mlp2 = nn.Sequential(
            nn.Linear(2 * wide, narrow),
            nn.ELU(),
            nn.Linear(narrow, hidden_channels),
            nn.ELU(),
        )
        self.edgeconv1 = EdgeConv(nn=edge_mlp1, aggr=aggr)
        # NOTE(review): edgeconv2 is registered but never used in forward();
        # kept so state_dict keys stay compatible with existing checkpoints.
        self.edgeconv2 = EdgeConv(nn=edge_mlp2, aggr=aggr)
        # Hyper-parameters shared by every GAT stage.
        gat_kwargs = dict(
            num_layers=num_layers,
            dropout=dropout,
            act=act,
            act_first=act_first,
            norm=norm,
            v2=v2,
            jk=jk,
        )
        # Channel widths follow the dense skip concatenations in forward():
        # 2h -> 2h, (2h+2h)=4h -> 2h, (2h+4h)=6h -> h.
        self.gat1 = GAT(in_channels=wide, hidden_channels=wide,
                        out_channels=wide, **gat_kwargs)
        self.gat2 = GAT(in_channels=2 * wide, hidden_channels=wide,
                        out_channels=wide, **gat_kwargs)
        self.gat3 = GAT(in_channels=3 * wide, hidden_channels=wide,
                        out_channels=hidden_channels, **gat_kwargs)
        # +2: two raw input columns are re-attached before pooling.
        self.pool = AttentionalAggregation(gate_nn=nn.Linear(hidden_channels + 2, 1))
        self.fc = nn.Linear(hidden_channels + 2, out_channels)
        self.fc_trk = nn.Linear(wide, out_channels)

    def forward(self, x, edge_index, batch):
        """Run the network on a batch of graphs.

        Note: the incoming ``edge_index`` is discarded — a fresh kNN graph is
        built from the embedded node positions instead.
        """
        # Keep raw inputs; columns 2 and 3 are re-appended before pooling.
        raw = x.clone()
        x = self.inputnet(x)
        embedded = x.clone()
        # Dynamic graph: k nearest neighbours in embedding space, with
        # self-loops, symmetrised to an undirected edge set.
        edge_index = to_undirected(
            knn_graph(x, self.k, batch, loop=True, flow=self.edgeconv1.flow)
        )
        x = torch.cat([self.edgeconv1(x, edge_index), embedded], dim=-1)
        # Per-node ("track") branch taps the EdgeConv features directly.
        trk = x.clone()
        skip = x.clone()
        x = torch.cat([self.gat1(x, edge_index), skip], dim=-1)
        skip = x.clone()
        x = torch.cat([self.gat2(x, edge_index), skip], dim=-1)
        x = self.gat3(x, edge_index)
        # Re-attach two raw input columns (presumably coordinates/features
        # the gate should see — TODO confirm against the data loader).
        x = torch.cat([x, raw[:, 2:4]], dim=-1)
        # Attention-weighted pooling of node features to graph-level features.
        x = self.pool(x, batch)
        return self.fc(x).squeeze(), self.fc_trk(trk).squeeze()