Skip to content

Commit 823a333

Browse files
committed
doc model resgated and gat
1 parent 61473e2 commit 823a333

2 files changed

Lines changed: 109 additions & 17 deletions

File tree

chebai_graph/models/gat.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,28 @@
77

88

99
class GATGraphConvNetBase(GraphModelBase):
10-
def __init__(self, config, **kwargs):
10+
"""
11+
Graph Attention Network (GAT) base module for graph convolution.
12+
13+
Uses PyTorch Geometric's `GAT` implementation to process atomic node features
14+
and bond edge attributes through multiple attention heads and layers.
15+
"""
16+
17+
def __init__(self, config: dict, **kwargs):
18+
"""
19+
Initialize the GATGraphConvNetBase.
20+
21+
Args:
22+
config (dict): Model configuration containing:
23+
- 'heads' (int): Number of attention heads.
24+
- 'v2' (bool): Whether to use the GATv2 variant.
25+
- Other required GraphModelBase parameters.
26+
**kwargs: Additional arguments for the base class.
27+
"""
1128
super().__init__(config=config, **kwargs)
1229
self.heads = int(config["heads"])
1330
self.v2 = bool(config["v2"])
14-
self.activation = ELU() # instantiate once
31+
self.activation = ELU() # Instantiate ELU once for reuse.
1532
self.gat = GAT(
1633
in_channels=self.n_atom_properties,
1734
hidden_channels=self.hidden_length,
@@ -24,20 +41,49 @@ def __init__(self, config, **kwargs):
2441
)
2542

2643
def forward(self, batch: dict) -> torch.Tensor:
44+
"""
45+
Forward pass through the GAT network.
46+
47+
Processes atomic node features and edge attributes, and applies
48+
an ELU activation to the output.
49+
50+
Args:
51+
batch (dict): Input batch containing:
52+
- 'features': A list with a `GraphData` object as its first element.
53+
54+
Returns:
55+
torch.Tensor: Node embeddings after GAT and activation.
56+
"""
2757
graph_data = batch["features"][0]
2858
assert isinstance(graph_data, GraphData)
2959

30-
a = self.gat(
60+
out = self.gat(
3161
x=graph_data.x.float(),
3262
edge_index=graph_data.edge_index,
3363
edge_attr=graph_data.edge_attr,
3464
)
3565

36-
return self.activation(a)
66+
return self.activation(out)
3767

3868

3969
class GATGraphPred(GraphNetWrapper):
70+
"""
71+
GAT-based graph prediction model.
72+
73+
Uses a `GATGraphConvNetBase` as the GNN backbone for generating node embeddings,
74+
which are then pooled by the GraphNetWrapper for final prediction.
75+
"""
76+
4077
NAME = "GATGraphPred"
4178

42-
def _get_gnn(self, config):
79+
def _get_gnn(self, config: dict) -> GATGraphConvNetBase:
80+
"""
81+
Instantiate the GAT graph convolutional network base.
82+
83+
Args:
84+
config (dict): Model configuration.
85+
86+
Returns:
87+
GATGraphConvNetBase: The initialized GNN.
88+
"""
4389
return GATGraphConvNetBase(config=config)

chebai_graph/models/resgated.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,36 @@
88

99

1010
class ResGatedGraphConvNetBase(GraphModelBase):
11-
"""GNN that supports edge attributes"""
11+
"""
12+
Residual Gated Graph Convolutional Network with edge attributes support.
13+
14+
This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric,
15+
allowing edge attributes as part of message passing. A final projection layer maps
16+
to the hidden length specified for downstream graph prediction tasks.
17+
"""
1218

1319
NAME = "ResGatedGraphConvNetBase"
1420

15-
def __init__(self, config, **kwargs):
21+
def __init__(self, config: dict, **kwargs):
22+
"""
23+
Initialize the ResGatedGraphConvNetBase.
24+
25+
Args:
26+
config (dict): Configuration dictionary with keys:
27+
- 'in_length' (int): Intermediate feature length used in GNN layers.
28+
- Other parameters inherited from GraphModelBase.
29+
**kwargs: Additional keyword arguments passed to GraphModelBase.
30+
"""
1631
super().__init__(config=config, **kwargs)
1732
self.in_length = int(config["in_length"])
1833

1934
self.activation = F.elu
2035
self.dropout = nn.Dropout(self.dropout_rate)
2136

22-
self.convs = torch.nn.ModuleList([])
37+
self.convs = torch.nn.ModuleList()
2338
for i in range(self.n_conv_layers):
2439
if i == 0:
40+
# Initial layer uses atom features as input
2541
self.convs.append(
2642
tgnn.ResGatedGraphConv(
2743
self.n_atom_properties,
@@ -30,36 +46,66 @@ def __init__(self, config, **kwargs):
3046
edge_dim=self.n_bond_properties,
3147
)
3248
)
49+
# Intermediate layers
3350
self.convs.append(
3451
tgnn.ResGatedGraphConv(
3552
self.in_length, self.in_length, edge_dim=self.n_bond_properties
3653
)
3754
)
55+
56+
# Final projection layer to hidden dimension
3857
self.final_conv = tgnn.ResGatedGraphConv(
3958
self.in_length, self.hidden_length, edge_dim=self.n_bond_properties
4059
)
4160

42-
def forward(self, batch):
61+
def forward(self, batch: dict) -> torch.Tensor:
62+
"""
63+
Forward pass through residual gated GNN layers.
64+
65+
Args:
66+
batch (dict): A batch containing:
67+
- 'features': A list with a `GraphData` instance as the first element.
68+
69+
Returns:
70+
torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length].
71+
"""
4372
graph_data = batch["features"][0]
4473
assert isinstance(graph_data, GraphData)
45-
a = graph_data.x.float()
46-
# a = self.embedding(a)
74+
75+
x = graph_data.x.float() # Atom features
4776

4877
for conv in self.convs:
4978
assert isinstance(conv, tgnn.ResGatedGraphConv)
50-
a = self.activation(
51-
conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr)
79+
x = self.activation(
80+
conv(x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr)
5281
)
53-
a = self.activation(
82+
83+
x = self.activation(
5484
self.final_conv(
55-
a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr
85+
x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr
5686
)
5787
)
58-
return a
88+
89+
return x
5990

6091

6192
class ResGatedGraphPred(GraphNetWrapper):
93+
"""
94+
Residual Gated GNN for Graph Prediction.
95+
96+
Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings.
97+
"""
98+
6299
NAME = "ResGatedGraphPred"
63100

64-
def _get_gnn(self, config):
101+
def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase:
102+
"""
103+
Instantiate the residual gated GNN backbone.
104+
105+
Args:
106+
config (dict): Model configuration.
107+
108+
Returns:
109+
ResGatedGraphConvNetBase: The GNN encoder.
110+
"""
65111
return ResGatedGraphConvNetBase(config=config)

0 commit comments

Comments
 (0)