88
99
1010class ResGatedGraphConvNetBase (GraphModelBase ):
11- """GNN that supports edge attributes"""
11+ """
12+ Residual Gated Graph Convolutional Network with edge attributes support.
13+
14+ This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric,
15+ allowing edge attributes as part of message passing. A final projection layer maps
16+ to the hidden length specified for downstream graph prediction tasks.
17+ """
1218
1319 NAME = "ResGatedGraphConvNetBase"
1420
15- def __init__ (self , config , ** kwargs ):
21+ def __init__ (self , config : dict , ** kwargs ):
22+ """
23+ Initialize the ResGatedGraphConvNetBase.
24+
25+ Args:
26+ config (dict): Configuration dictionary with keys:
27+ - 'in_length' (int): Intermediate feature length used in GNN layers.
28+ - Other parameters inherited from GraphModelBase.
29+ **kwargs: Additional keyword arguments passed to GraphModelBase.
30+ """
1631 super ().__init__ (config = config , ** kwargs )
1732 self .in_length = int (config ["in_length" ])
1833
1934 self .activation = F .elu
2035 self .dropout = nn .Dropout (self .dropout_rate )
2136
22- self .convs = torch .nn .ModuleList ([] )
37+ self .convs = torch .nn .ModuleList ()
2338 for i in range (self .n_conv_layers ):
2439 if i == 0 :
40+ # Initial layer uses atom features as input
2541 self .convs .append (
2642 tgnn .ResGatedGraphConv (
2743 self .n_atom_properties ,
@@ -30,36 +46,66 @@ def __init__(self, config, **kwargs):
3046 edge_dim = self .n_bond_properties ,
3147 )
3248 )
49+ # Intermediate layers
3350 self .convs .append (
3451 tgnn .ResGatedGraphConv (
3552 self .in_length , self .in_length , edge_dim = self .n_bond_properties
3653 )
3754 )
55+
56+ # Final projection layer to hidden dimension
3857 self .final_conv = tgnn .ResGatedGraphConv (
3958 self .in_length , self .hidden_length , edge_dim = self .n_bond_properties
4059 )
4160
42- def forward (self , batch ):
61+ def forward (self , batch : dict ) -> torch .Tensor :
62+ """
63+ Forward pass through residual gated GNN layers.
64+
65+ Args:
66+ batch (dict): A batch containing:
67+ - 'features': A list with a `GraphData` instance as the first element.
68+
69+ Returns:
70+ torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length].
71+ """
4372 graph_data = batch ["features" ][0 ]
4473 assert isinstance (graph_data , GraphData )
45- a = graph_data . x . float ()
46- # a = self.embedding(a)
74+
75+ x = graph_data . x . float () # Atom features
4776
4877 for conv in self .convs :
4978 assert isinstance (conv , tgnn .ResGatedGraphConv )
50- a = self .activation (
51- conv (a , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr )
79+ x = self .activation (
80+ conv (x , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr )
5281 )
53- a = self .activation (
82+
83+ x = self .activation (
5484 self .final_conv (
55- a , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr
85+ x , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr
5686 )
5787 )
58- return a
88+
89+ return x
5990
6091
6192class ResGatedGraphPred (GraphNetWrapper ):
93+ """
94+ Residual Gated GNN for Graph Prediction.
95+
96+ Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings.
97+ """
98+
6299 NAME = "ResGatedGraphPred"
63100
64- def _get_gnn (self , config ):
101+ def _get_gnn (self , config : dict ) -> ResGatedGraphConvNetBase :
102+ """
103+ Instantiate the residual gated GNN backbone.
104+
105+ Args:
106+ config (dict): Model configuration.
107+
108+ Returns:
109+ ResGatedGraphConvNetBase: The GNN encoder.
110+ """
65111 return ResGatedGraphConvNetBase (config = config )
0 commit comments