-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
282 lines (218 loc) · 9.37 KB
/
model.py
File metadata and controls
282 lines (218 loc) · 9.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import torch
import torch.nn as nn
from torch_geometric.nn import SAGPooling, GINConv, SAGEConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
def normalize_edges(num_nodes, edge_index, edge_weight):
    """Compute symmetric D^{-1/2} A D^{-1/2} edge weights for graph convolution.

    The degree of a node is the sum of ``edge_weight`` over the edges that
    list it as their source.

    Args:
        num_nodes: Number of nodes; sizes the degree vector.
        edge_index: Integer tensor that unpacks into ``(src, dst)`` index rows.
        edge_weight: One weight per edge.

    Returns:
        Per-edge normalized weights ``w * d[src]^-1/2 * d[dst]^-1/2``. When
        there are no edges, ``edge_weight`` is returned unchanged.
    """
    if edge_index.numel() == 0:
        return edge_weight

    src, dst = edge_index
    deg = torch.zeros(num_nodes, device=edge_index.device, dtype=edge_weight.dtype)
    deg.scatter_add_(0, src, edge_weight)

    # A node without any accumulated degree must contribute a factor of 0.
    # Clamping alone would turn it into 1e6 and blow up every edge that
    # touches it (e.g. a directed edge into a sink node). The clamp is kept
    # for positive degrees so tiny-but-nonzero degrees stay bounded.
    d_inv_sqrt = torch.where(
        deg > 0,
        deg.clamp(min=1e-12).pow(-0.5),
        torch.zeros_like(deg),
    )
    return edge_weight * d_inv_sqrt[src] * d_inv_sqrt[dst]
def propagate(x, edge_index, edge_weight_norm):
    """Sum weighted neighbor features into each destination node.

    Args:
        x: Node feature tensor indexed by node along dim 0.
        edge_index: Tensor that unpacks into ``(src, dst)`` index rows.
        edge_weight_norm: One weight per edge; cast to ``x.dtype`` before use.

    Returns:
        Tensor shaped like ``x`` where row ``i`` is the sum of
        ``weight * x[src]`` over edges whose destination is ``i``. All zeros
        when there are no edges.
    """
    aggregated = torch.zeros_like(x)
    if edge_index.numel() > 0:
        sources, targets = edge_index
        # Broadcast one scalar weight per edge across the feature dimension.
        per_edge_scale = edge_weight_norm.to(x.dtype).unsqueeze(-1)
        aggregated.index_add_(0, targets, x[sources] * per_edge_scale)
    return aggregated
class AtomEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_size: int, dropout: float = 0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_size),
nn.RMSNorm(hidden_size),
nn.ReLU(),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class ShortGINE(nn.Module):
    """Residual GIN layer that can fold projected edge features into messages.

    When an edge MLP exists and edge attributes are supplied, messages are
    ``x[src] + edge_mlp(edge_attr)``, summed per destination node and passed
    through the GIN update ``nn((1 + eps) * x + aggregated)``. Otherwise the
    wrapped ``GINConv`` is called directly. Both paths end with dropout and a
    residual connection.

    NOTE(review): ``self.norm`` is registered but ``forward`` never applies
    it — confirm whether that is intentional.
    """

    def __init__(self, in_dim, edge_dim, dropout=0.0):
        """Build the layer.

        Args:
            in_dim: Node feature width (input and output).
            edge_dim: Edge feature width; no edge MLP is built when it is 0
                or negative.
            dropout: Dropout probability used inside the node MLP and on the
                layer output.
        """
        super().__init__()

        # MLP applied to the combined node/neighborhood representation.
        update_net = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_dim, in_dim),
        )

        # Projects edge attributes to node width so they can be added to messages.
        if edge_dim > 0:
            edge_net = nn.Sequential(
                nn.Linear(edge_dim, in_dim),
                nn.ReLU(),
                nn.Linear(in_dim, in_dim),
            )
        else:
            edge_net = None

        self.conv = GINConv(update_net, train_eps=True)
        self.edge_mlp = edge_net
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.RMSNorm(in_dim)

    def forward(self, x, edge_index, edge_attr):
        """Run one residual message-passing step.

        Args:
            x: Node features.
            edge_index: Tensor that unpacks into ``(src, dst)`` index rows.
            edge_attr: Edge features, or ``None`` to skip the edge-aware path.

        Returns:
            ``dropout(update) + x``.
        """
        edge_aware = self.edge_mlp is not None and edge_attr is not None
        if not edge_aware:
            # No usable edge features: defer to the stock GIN convolution.
            return self.dropout(self.conv(x, edge_index)) + x

        edge_feats = self.edge_mlp(edge_attr)
        src, dst = edge_index
        messages = x[src] + edge_feats

        # Sum aggregation: scatter each message row into its destination node.
        scatter_index = dst.unsqueeze(-1).expand(-1, x.size(-1))
        summed = torch.zeros_like(x)
        summed.scatter_add_(0, scatter_index, messages)

        # GIN update, reusing the convolution's eps and MLP.
        eps = getattr(self.conv, 'eps', 0)
        updated = self.conv.nn((1 + eps) * x + summed)
        return self.dropout(updated) + x
class LongPoly(nn.Module):
def __init__(self, hidden_size, K=5, groups=4, dropout=0.1):
super().__init__()
assert hidden_size % groups == 0, "hidden_size must be divisible by groups"
self.K = K
self.groups = groups
self.group_channels = hidden_size // groups
# More efficient parameter structure
self.cheb_coeffs = nn.Parameter(torch.empty(groups, K + 1))
nn.init.xavier_uniform_(self.cheb_coeffs, gain=0.1)
# Add learnable scaling and bias per group
self.group_scale = nn.Parameter(torch.ones(groups))
self.group_bias = nn.Parameter(torch.zeros(groups))
# Lightweight normalization and activation
self.norm = nn.RMSNorm(hidden_size)
self.activation = nn.SiLU()
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
# Cache for computational efficiency
self.register_buffer('_cached_edge_index', None)
self.register_buffer('_cached_polynomials', None)
def forward(self, x, edge_index, edge_weight_norm):
N, H = x.shape
# Early return for empty graphs
if edge_index.numel() == 0:
x_grouped = x.view(N, self.groups, self.group_channels)
result = self.cheb_coeffs[:, 0].view(1, -1, 1) * x_grouped
result = result * self.group_scale.view(1, -1, 1) + self.group_bias.view(1, -1, 1)
return self.dropout(self.activation(self.norm(result.reshape(N, H))))
# Efficient Chebyshev computation without storing all polynomials
x_grouped = x.view(N, self.groups, self.group_channels)
result = self.cheb_coeffs[:, 0].view(1, -1, 1) * x_grouped # T_0 term
if self.K >= 1:
T_prev2 = x # T_0
T_prev1 = propagate(x, edge_index, edge_weight_norm) # T_1
# Add T_1 contribution
T1_grouped = T_prev1.view(N, self.groups, self.group_channels)
result += self.cheb_coeffs[:, 1].view(1, -1, 1) * T1_grouped
# Compute higher order terms on-the-fly
for k in range(2, self.K + 1):
T_curr = 2 * propagate(T_prev1, edge_index, edge_weight_norm) - T_prev2
T_curr_grouped = T_curr.view(N, self.groups, self.group_channels)
result += self.cheb_coeffs[:, k].view(1, -1, 1) * T_curr_grouped
# Update for next iteration
T_prev2, T_prev1 = T_prev1, T_curr
# Apply group-wise scaling and bias
result = result * self.group_scale.view(1, -1, 1) + self.group_bias.view(1, -1, 1)
# Final transformation
output = result.reshape(N, H)
return self.dropout(self.activation(self.norm(output)))
# ------------------------------
# Main GraphCliff Filter
# ------------------------------
class GraphCliffFilter(nn.Module):
    """One gated filter block combining a short-range and a long-range path.

    The input is layer-normalized and projected to three times the hidden
    width, refined by ``ShortGINE``, then split into three equal chunks: one
    fed to ``LongPoly``, one turned into a sigmoid gate, and one used as an
    additive value branch. The block output adds the original input back.
    """

    def __init__(self,
                 hidden_size,
                 edge_dim,
                 groups=4,
                 short_dropout=0.1,
                 mid_K=3):
        """Build the block.

        Args:
            hidden_size: Node feature width of the block's input and output.
            edge_dim: Edge feature width handed to ``ShortGINE``.
            groups: Channel groups used by ``LongPoly``.
            short_dropout: Dropout probability for ``ShortGINE``.
            mid_K: Polynomial order for ``LongPoly``.
        """
        super().__init__()
        self.groups = groups
        self.pre_norm = nn.LayerNorm(hidden_size)

        # A single projection produces all three branches at once.
        widened = 3 * hidden_size
        self.proj = nn.Linear(hidden_size, widened, bias=True)
        nn.init.xavier_normal_(self.proj.weight, gain=1)
        nn.init.zeros_(self.proj.bias)

        self.short = ShortGINE(widened, edge_dim, short_dropout)
        self.long = LongPoly(hidden_size, K=mid_K, groups=groups)

    def forward(self, u, edge_index, edge_attr):
        """Apply the block.

        Args:
            u: Node features ``[N, hidden_size]``.
            edge_index: Edge index tensor shared by both paths.
            edge_attr: Edge features for the short path (may be ``None``).

        Returns:
            Tensor shaped like ``u``: ``long(x2) * sigmoid(x1) + v + u``.
        """
        projected = self.proj(self.pre_norm(u))  # [N, 3H]
        mixed = self.short(projected, edge_index, edge_attr)
        long_in, gate_in, value = mixed.chunk(3, dim=-1)  # each [N, H]

        # Edge normalization uses unit weights and is computed once per call.
        if edge_index.numel() == 0:
            edge_norm = torch.empty(0, device=u.device, dtype=u.dtype)
        else:
            unit_weights = torch.ones(
                edge_index.size(1), device=edge_index.device, dtype=long_in.dtype)
            edge_norm = normalize_edges(u.size(0), edge_index, unit_weights)

        long_out = self.long(long_in, edge_index, edge_norm)
        gated = long_out * torch.sigmoid(gate_in) + value
        return gated + u
class GraphCliffEncoder(nn.Module):
    """Sequential stack of ``GraphCliffFilter`` blocks."""

    def __init__(self,
                 hidden_size,
                 edge_dim,
                 num_layers=3,
                 groups=4,
                 mid_K=3,
                 dropout=0.1):
        """Build the stack.

        Args:
            hidden_size: Node feature width shared by every block.
            edge_dim: Edge feature width.
            num_layers: Number of filter blocks.
            groups: Channel groups passed to each block.
            mid_K: Polynomial order passed to each block.
            dropout: Base dropout; each block receives half of it as its
                short-path dropout.
        """
        super().__init__()
        short_dropout = dropout * 0.5
        blocks = nn.ModuleList()
        for _ in range(num_layers):
            blocks.append(
                GraphCliffFilter(
                    hidden_size, edge_dim, groups,
                    short_dropout=short_dropout, mid_K=mid_K))
        self.layers = blocks

    def forward(self, x, edge_index, edge_attr):
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, edge_index, edge_attr)
        return hidden
class GraphCliffRegressor(nn.Module):
    """Graph-level regressor: atom embedding, filter stack, pooling, MLP head.

    Node features are embedded, refined by ``GraphCliffEncoder``, reduced with
    ``SAGPooling``, then summarized per graph by concatenating max and mean
    readouts before the regression head.
    """

    def __init__(self,
                 atom_in_dim,
                 edge_dim,
                 hidden_size=256,
                 num_layers=3,
                 groups=4,
                 mid_K=3,
                 dropout=0.1):
        """Build the model.

        Args:
            atom_in_dim: Width of the raw atom feature vectors.
            edge_dim: Edge feature width.
            hidden_size: Node embedding width.
            num_layers: Number of encoder blocks.
            groups: Channel groups for the encoder blocks.
            mid_K: Polynomial order for the encoder blocks.
            dropout: Dropout probability for the embedding, encoder and head.
        """
        super().__init__()
        self.atom_encoder = AtomEncoder(atom_in_dim, hidden_size, dropout)
        self.encoder = GraphCliffEncoder(
            hidden_size, edge_dim, num_layers, groups, mid_K, dropout)
        self.sagpool = SAGPooling(in_channels=hidden_size, ratio=0.8, GNN=SAGEConv)

        # The readout concatenates max and mean pooling, doubling the width.
        readout_dim = hidden_size * 2
        head_dim = hidden_size // 2
        self.reg_head = nn.Sequential(
            nn.Linear(readout_dim, head_dim),
            nn.LayerNorm(head_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, 1),
        )

    def forward(self, x_f, edge_index, edge_attr, batch):
        """Predict one value per graph.

        Args:
            x_f: Raw atom features.
            edge_index: Edge index tensor.
            edge_attr: Edge features.
            batch: Node-to-graph assignment vector.

        Returns:
            Output of the regression head, one row per graph in ``batch``.
        """
        embedded = self.atom_encoder(x_f)
        encoded = self.encoder(embedded, edge_index, edge_attr)

        # Only the pooled node features and their batch vector are used below.
        kept_x, _, _, kept_batch, _, _ = self.sagpool(
            encoded, edge_index, edge_attr, batch)

        readout = torch.cat([gmp(kept_x, kept_batch), gap(kept_x, kept_batch)], dim=1)
        return self.reg_head(readout)