From 39a5862ed5d6a8d8b4ff61cd9e614cb374e42732 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:35 -0500 Subject: [PATCH 01/62] added basic GRIT code Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 57 ++++++ gridfm_graphkit/models/grit_transformer.py | 195 +++++++++++++++++++++ 2 files changed, 252 insertions(+) create mode 100644 examples/config/grit_pretraining.yaml create mode 100644 gridfm_graphkit/models/grit_transformer.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml new file mode 100644 index 0000000..a6566e0 --- /dev/null +++ b/examples/config/grit_pretraining.yaml @@ -0,0 +1,57 @@ +callbacks: + patience: 100 + tol: 0 +data: + baseMVA: 100 + learn_mask: false + mask_dim: 6 + mask_ratio: 0.5 + mask_type: rnd + mask_value: -1.0 + networks: + # - Texas2k_case1_2016summerpeak + - case24_ieee_rts + # - case118_ieee + # - case300_ieee + - case89_pegase + # - case240_pserc + normalization: baseMVAnorm + scenarios: + # - 5000 + - 5000 + - 5000 + # - 30000 + # - 50000 + # - 50000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 123 + input_dim: 9 + num_layers: 14 + output_dim: 6 + pe_dim: 20 + type: GPSTransformer # +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.99 + losses: + - MaskedMSE + - PBE + accelerator: auto + devices: auto + strategy: auto diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py new file mode 100644 index 0000000..3ee5e8e --- /dev/null +++ b/gridfm_graphkit/models/grit_transformer.py @@ -0,0 +1,195 @@ +from gridfm_graphkit.io.registries import MODELS_REGISTRY +from torch import nn +import torch +import torch_geometric.graphgym.register as register +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.models.gnn import GNNPreMP +from torch_geometric.graphgym.models.layer import (new_layer_config, + BatchNorm1dNode) +from torch_geometric.graphgym.register import register_network +from torch_geometric.graphgym.models.layer import new_layer_config, MLP + + + +class FeatureEncoder(torch.nn.Module): + """ + Encoding node and edge features + + Args: + dim_in (int): Input feature dimension + """ + def __init__(self, dim_in): + super(FeatureEncoder, self).__init__() + self.dim_in = dim_in + if cfg.dataset.node_encoder: + # Encode integer node features via nn.Embeddings + NodeEncoder = register.node_encoder_dict[ + cfg.dataset.node_encoder_name] + self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) + if cfg.dataset.node_encoder_bn: + self.node_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + # Update dim_in to reflect the new dimension fo the node features + self.dim_in = cfg.gnn.dim_inner + if cfg.dataset.edge_encoder: + # Hard-limit max edge dim for PNA. 
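+            # (presumably a memory bound for PNA's edge-conditioned aggregation)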
+ if 'PNA' in cfg.gt.layer_type: + cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + else: + cfg.gnn.dim_edge = cfg.gnn.dim_inner + # Encode integer edge features via nn.Embeddings + EdgeEncoder = register.edge_encoder_dict[ + cfg.dataset.edge_encoder_name] + self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + if cfg.dataset.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode( + new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + has_bias=False, cfg=cfg)) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + return batch + + +@register_head('decoder_head') +class GNNDecoderHead(nn.Module): + """ + Predictoin head for encoder-decoder networks. + + Args: + dim_in (int): Input dimension # TODO update arg comments as needed + dim_out (int): Output dimension. For binary prediction, dim_out=1. + """ + + def __init__(self, dim_in, dim_out): + super(GNNDecoderHead, self).__init__() + + + + # note that the input and output dimensions are from the config file + # if we want this to be variable that will have to change with + # each layer + + # TODO consider use of a bottleneck + + # note the config is imported as in other modules + + # the number of config layers should apriori be different than the encoder + + + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gnn.layers_decode): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + + + + self.layer_post_mp = MLP( + new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, + has_act=False, has_bias=True, cfg=cfg)) + + + + def _apply_index(self, batch): + return batch.x, batch.y + + def forward(self, batch): + batch = self.layers(batch) + + # follow GMAE here and make a final linear projection from the + # hiden dimension to the output dimension + batch = self.layer_post_mp(batch) + + pred, label = self._apply_index(batch) + #print('>>>>>>', pred.size(),label.size()) + return pred, label + + + +@MODELS_REGISTRY.register("GRIT") +class GritTransformer(torch.nn.Module): + ''' + The proposed GritTransformer (Graph Inductive Bias Transformer) + ''' + + def __init__(self, dim_in, dim_out): + super().__init__() + self.encoder = FeatureEncoder(dim_in) + dim_in = self.encoder.dim_in + + self.ablation = True + self.ablation = False + + if cfg.posenc_RRWP.enable: + self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ + (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) + rel_pe_dim = cfg.posenc_RRWP.ksteps + self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ + (rel_pe_dim, cfg.gnn.dim_edge, + pad_to_full_graph=cfg.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) + + + if cfg.gnn.layers_pre_mp > 0: + self.pre_mp = GNNPreMP( + dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) + dim_in = cfg.gnn.dim_inner + + assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + "The inner and hidden dims must match." 
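+        # encoder output, the optional pre-MP stage, and the transformer stack all run at this one shared width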
+ + global_model_type = cfg.gt.get('layer_type', "GritTransformer") + # global_model_type = "GritTransformer" + + TransformerLayer = register.layer_dict.get(global_model_type) + + layers = [] + for l in range(cfg.gt.layers): + layers.append(TransformerLayer( + in_dim=cfg.gt.dim_hidden, + out_dim=cfg.gt.dim_hidden, + num_heads=cfg.gt.n_heads, + dropout=cfg.gt.dropout, + act=cfg.gnn.act, + attn_dropout=cfg.gt.attn_dropout, + layer_norm=cfg.gt.layer_norm, + batch_norm=cfg.gt.batch_norm, + residual=True, + norm_e=cfg.gt.attn.norm_e, + O_e=cfg.gt.attn.O_e, + cfg=cfg.gt, + )) + # layers = [] + + self.layers = torch.nn.Sequential(*layers) + GNNHead = register.head_dict[cfg.gnn.head] + self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + + return batch \ No newline at end of file From 922d6cefe51a19d31cff2b3cda09603e2600a349 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 02/62] initial connection of model to config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 35 +++- gridfm_graphkit/models/grit_transformer.py | 202 ++++++++------------- 2 files changed, 113 insertions(+), 124 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index a6566e0..904d6dc 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -32,10 +32,41 @@ model: edge_dim: 2 hidden_size: 123 input_dim: 9 - num_layers: 14 + num_layers: 10 output_dim: 6 pe_dim: 20 - type: GPSTransformer # + type: GRIT #GPSTransformer # + layers_pre_mp: 0 + act: relu + encoder: + node_encoder: True + edge_encoder: True + node_encoder_name: TODO + node_encoder_bn: True + gt: + layer_type: GritTransformer + # layers: 10 + # n_heads: 8 + dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + # dropout: 0.0 + layer_norm: False + batch_norm: True + update_e: True + attn_dropout: 0.2 + attn: + clamp: 5. 
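+      # clamp: attention logits are clipped to [-clamp, clamp] before the softmax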
+ act: 'relu' + full_attn: True + edge_enhance: True + O_e: True + norm_e: True + signed_sqrt: True + posenc_RRWP: + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3ee5e8e..b09f527 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,12 +1,10 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY from torch import nn import torch -import torch_geometric.graphgym.register as register -from torch_geometric.graphgym.config import cfg + from torch_geometric.graphgym.models.gnn import GNNPreMP from torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) -from torch_geometric.graphgym.register import register_network from torch_geometric.graphgym.models.layer import new_layer_config, MLP @@ -17,114 +15,49 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension + + + TODO replace 'register' with local version of it + """ - def __init__(self, dim_in): + def __init__( + self, + dim_in, + dim_inner, + args + ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if cfg.dataset.node_encoder: + if args.node_encoder: # Encode integer node features via nn.Embeddings NodeEncoder = register.node_encoder_dict[ - cfg.dataset.node_encoder_name] - self.node_encoder = NodeEncoder(cfg.gnn.dim_inner) - if cfg.dataset.node_encoder_bn: + args.node_encoder_name] + self.node_encoder = NodeEncoder(dim_inner) + if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False, + new_layer_config(dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg)) # Update dim_in to reflect the new dimension fo the node features - self.dim_in = cfg.gnn.dim_inner - if cfg.dataset.edge_encoder: + self.dim_in = dim_inner + if args.edge_encoder: # Hard-limit max edge dim for PNA. - if 'PNA' in cfg.gt.layer_type: - cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner) + if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed + dim_edge = min(128, dim_inner) else: - cfg.gnn.dim_edge = cfg.gnn.dim_inner + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings EdgeEncoder = register.edge_encoder_dict[ cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge) + self.edge_encoder = EdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False, + new_layer_config(dim_edge, -1, -1, has_act=False, has_bias=False, cfg=cfg)) def forward(self, batch): for module in self.children(): batch = module(batch) return batch - - -@register_head('decoder_head') -class GNNDecoderHead(nn.Module): - """ - Predictoin head for encoder-decoder networks. - - Args: - dim_in (int): Input dimension # TODO update arg comments as needed - dim_out (int): Output dimension. For binary prediction, dim_out=1. 
- """ - - def __init__(self, dim_in, dim_out): - super(GNNDecoderHead, self).__init__() - - - - # note that the input and output dimensions are from the config file - # if we want this to be variable that will have to change with - # each layer - - # TODO consider use of a bottleneck - - # note the config is imported as in other modules - - # the number of config layers should apriori be different than the encoder - - - global_model_type = cfg.gt.get('layer_type', "GritTransformer") - - TransformerLayer = register.layer_dict.get(global_model_type) - - layers = [] - for l in range(cfg.gnn.layers_decode): - layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, # TODO could migrate this and others to gnn in config - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, - residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, - )) - # layers = [] - - self.layers = torch.nn.Sequential(*layers) - - - - self.layer_post_mp = MLP( - new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, - has_act=False, has_bias=True, cfg=cfg)) - - - - def _apply_index(self, batch): - return batch.x, batch.y - - def forward(self, batch): - batch = self.layers(batch) - - # follow GMAE here and make a final linear projection from the - # hiden dimension to the output dimension - batch = self.layer_post_mp(batch) - - pred, label = self._apply_index(batch) - #print('>>>>>>', pred.size(),label.size()) - return pred, label - @MODELS_REGISTRY.register("GRIT") @@ -133,60 +66,85 @@ class GritTransformer(torch.nn.Module): The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, dim_in, dim_out): + def __init__(self, args): super().__init__() - self.encoder = FeatureEncoder(dim_in) - dim_in = self.encoder.dim_in - self.ablation = True - self.ablation = False + # ### TODO remove default args not needed #### + # self.input_dim = + # self.hidden_dim = + # self.output_dim = + # self.edge_dim = + # self.num_layers = args.model.num_layers + # self.heads = getattr(args.model, "attention_head", 1) + # self.dropout = getattr(args.model, "dropout", 0.0) + # ### ### + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + dim_edge = args.model.edge_dim + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + + self.encoder = FeatureEncoder( + dim_in, + dim_inner, + args.model.encoder + ) # TODO add args + dim_in = self.encoder.dim_in + - if cfg.posenc_RRWP.enable: + if args.model.posenc_RRWP.enable: + # TODO connect 'register' to local version self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner) - rel_pe_dim = cfg.posenc_RRWP.ksteps + (args.model.posenc_RRWP.ksteps, dim_inner) + rel_pe_dim = args.model.posenc_RRWP.ksteps self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, cfg.gnn.dim_edge, - pad_to_full_graph=cfg.gt.attn.full_attn, + (rel_pe_dim, dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. 
) - if cfg.gnn.layers_pre_mp > 0: + if args.model.layers_pre_mp > 0: self.pre_mp = GNNPreMP( - dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) - dim_in = cfg.gnn.dim_inner + dim_in, dim_inner, args.model.layers_pre_mp) + dim_in = dim_inner - assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ + assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." - global_model_type = cfg.gt.get('layer_type', "GritTransformer") + global_model_type = args.model.gt.layer_type # global_model_type = "GritTransformer" - + # TODO replace this with local register logic TransformerLayer = register.layer_dict.get(global_model_type) layers = [] - for l in range(cfg.gt.layers): + for ll in range(num_layers): layers.append(TransformerLayer( - in_dim=cfg.gt.dim_hidden, - out_dim=cfg.gt.dim_hidden, - num_heads=cfg.gt.n_heads, - dropout=cfg.gt.dropout, - act=cfg.gnn.act, - attn_dropout=cfg.gt.attn_dropout, - layer_norm=cfg.gt.layer_norm, - batch_norm=cfg.gt.batch_norm, + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, residual=True, - norm_e=cfg.gt.attn.norm_e, - O_e=cfg.gt.attn.O_e, - cfg=cfg.gt, + norm_e=args.model.gt.attn.norm_e, + O_e=args.model.gt.attn.O_e, + cfg=args.model.gt, )) - # layers = [] - self.layers = torch.nn.Sequential(*layers) - GNNHead = register.head_dict[cfg.gnn.head] - self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) + self.layers = nn.Sequential(*layers) + + self.decoder = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, dim_out), + ) def forward(self, batch): for module in self.children(): From e8281ac03c1f97fed0b57419b96665fc4fb90020 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 03/62] collect model components and replace old register method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 346 +++++++++++++++++++++ gridfm_graphkit/models/grit_transformer.py | 79 +++-- gridfm_graphkit/models/rrwp_encoder.py | 192 ++++++++++++ 3 files changed, 583 insertions(+), 34 deletions(-) create mode 100644 gridfm_graphkit/models/grit_layer.py create mode 100644 gridfm_graphkit/models/rrwp_encoder.py diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py new file mode 100644 index 0000000..dc6cf97 --- /dev/null +++ b/gridfm_graphkit/models/grit_layer.py @@ -0,0 +1,346 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter, scatter_max, scatter_add + +from grit.utils import negate_edge_index +from torch_geometric.graphgym.register import * +import opt_einsum as oe + +from yacs.config import CfgNode as CN + +import warnings + +def pyg_softmax(src, index, num_nodes=None): + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. 
+ index (LongTensor): The indices of elements for applying the softmax. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + + num_nodes = maybe_num_nodes(index, num_nodes) + + out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] + out = out.exp() + out = out / ( + scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + + return out + + + +class MultiHeadAttentionLayerGritSparse(nn.Module): + """ + Proposed Attention Computation for GRIT + """ + + def __init__(self, in_dim, out_dim, num_heads, use_bias, + clamp=5., dropout=0., act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg=CN(), + **kwargs): + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = np.abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + nn.init.xavier_normal_(self.Aw) + + if act is None: + self.act = nn.Identity() + else: + self.act = act_dict[act]() + + if self.edge_enhance: + self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + nn.init.xavier_normal_(self.VeRow) + + def propagate_attention(self, batch): + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + + if batch.get("E", None) is not None: + batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + raw_attn = score + score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = 
Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get('wE', None) + + return h_out, e_out + + +@register_layer("GritTransformer") +class GritTransformerLayer(nn.Module): + """ + Proposed Transformer Layer for GRIT + """ + def __init__(self, in_dim, out_dim, num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, batch_norm=True, + residual=True, + act='relu', + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs): + super().__init__() + + self.debug = False + self.in_channels = in_dim + self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = cfg.get("update_e", True) + self.bn_momentum = cfg.bn_momentum + self.bn_no_runner = cfg.bn_no_runner + self.rezero = cfg.get("rezero", False) + + self.act = act_dict[act]() if act is not None else nn.Identity() + if cfg.get("attn", None) is None: + cfg.attn = dict() + self.use_attn = cfg.attn.get("use", True) + # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) + self.deg_scaler = cfg.attn.get("deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=cfg.attn.get("edge_enhance", True), + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + if cfg.attn.get('graphormer_attn', False): + self.attention = MultiHeadAttentionLayerGraphormerSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=cfg.attn.get("use_bias", False), + dropout=attn_dropout, + clamp=cfg.attn.get("clamp", 5.), + act=cfg.attn.get("act", "relu"), + edge_enhance=True, + sqrt_relu=cfg.attn.get("sqrt_relu", False), + signed_sqrt=cfg.attn.get("signed_sqrt", False), + scaled_attn =cfg.attn.get("scaled_attn", False), + no_qk=cfg.attn.get("no_qk", False), + ) + + + + self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option ------ + + if self.deg_scaler: + self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + 
self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1,1)) + self.alpha2_h = nn.Parameter(torch.zeros(1,1)) + self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: h = h * self.alpha1_h + h = h_in1 + h # residual connection + if e is not None: + if self.rezero: e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + if self.rezero: h = h * self.alpha2_h + h = h_in2 + h # residual connection + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + batch.x = h + if self.update_e: + batch.edge_attr = e + else: + batch.edge_attr = e_in1 + + return batch + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( + self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual, + super().__repr__(), + ) + + +@torch.no_grad() +def get_log_deg(batch): + if "log_deg" in batch: + log_deg = batch.log_deg + elif "deg" in batch: + deg = batch.deg + log_deg = torch.log(deg + 1).unsqueeze(-1) + else: + warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") + deg = pyg.utils.degree(batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float + ) + log_deg = torch.log(deg + 1) + log_deg = log_deg.view(batch.num_nodes, 1) + return log_deg + + diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index b09f527..e3c6047 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -2,12 +2,42 @@ from torch import nn import torch +from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from grit_layer import GritTransformerLayer + +# TODO verify use from torch_geometric.graphgym.models.gnn import GNNPreMP from torch_geometric.graphgym.models.layer import (new_layer_config, BatchNorm1dNode) from torch_geometric.graphgym.models.layer import new_layer_config, MLP +class LinearNodeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + + self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + + def forward(self, batch): + batch.x = self.encoder(batch.x) + return batch + +class LinearEdgeEncoder(torch.nn.Module): + def __init__(self, emb_dim): + super().__init__() + if cfg.dataset.name in 
['MNIST', 'CIFAR10']: + self.in_dim = 1 + elif cfg.dataset.name.startswith('attributed_triangle-'): + self.in_dim = 2 + else: + raise ValueError("Input edge feature dim is required to be hardset " + "or refactored to use a cfg option.") + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) + + def forward(self, batch): + batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) + return batch + class FeatureEncoder(torch.nn.Module): """ @@ -16,9 +46,6 @@ class FeatureEncoder(torch.nn.Module): Args: dim_in (int): Input feature dimension - - TODO replace 'register' with local version of it - """ def __init__( self, @@ -30,9 +57,7 @@ def __init__( self.dim_in = dim_in if args.node_encoder: # Encode integer node features via nn.Embeddings - NodeEncoder = register.node_encoder_dict[ - args.node_encoder_name] - self.node_encoder = NodeEncoder(dim_inner) + self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode( new_layer_config(dim_inner, -1, -1, has_act=False, @@ -46,9 +71,7 @@ def __init__( else: dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - EdgeEncoder = register.edge_encoder_dict[ - cfg.dataset.edge_encoder_name] - self.edge_encoder = EdgeEncoder(dim_edge) + self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: self.edge_encoder_bn = BatchNorm1dNode( new_layer_config(dim_edge, -1, -1, has_act=False, @@ -65,19 +88,9 @@ class GritTransformer(torch.nn.Module): ''' The proposed GritTransformer (Graph Inductive Bias Transformer) ''' - def __init__(self, args): super().__init__() - # ### TODO remove default args not needed #### - # self.input_dim = - # self.hidden_dim = - # self.output_dim = - # self.edge_dim = - # self.num_layers = args.model.num_layers - # self.heads = getattr(args.model, "attention_head", 1) - # self.dropout = getattr(args.model, "dropout", 0.0) - # ### ### dim_in = args.model.input_dim dim_out = args.model.output_dim @@ -96,16 +109,19 @@ def __init__(self, args): if args.model.posenc_RRWP.enable: - # TODO connect 'register' to local version - self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"]\ - (args.model.posenc_RRWP.ksteps, dim_inner) + + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.model.posenc_RRWP.ksteps, + dim_inner + ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"] \ - (rel_pe_dim, dim_edge, - pad_to_full_graph=args.model.gt.attn.full_attn, - add_node_attr_as_self_loop=False, - fill_value=0. - ) + self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + rel_pe_dim, + dim_edge, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) if args.model.layers_pre_mp > 0: @@ -116,14 +132,9 @@ def __init__(self, args): assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." 
- global_model_type = args.model.gt.layer_type - # global_model_type = "GritTransformer" - # TODO replace this with local register logic - TransformerLayer = register.layer_dict.get(global_model_type) - layers = [] for ll in range(num_layers): - layers.append(TransformerLayer( + layers.append(GritTransformerLayer( in_dim=args.model.gt.dim_hidden, out_dim=args.model.gt.dim_hidden, num_heads=num_heads, diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py new file mode 100644 index 0000000..f98118e --- /dev/null +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -0,0 +1,192 @@ +''' + The RRWP encoder for GRIT (ours) +''' +import torch +from torch import nn +from torch.nn import functional as F +from ogb.utils.features import get_bond_feature_dims +import torch_sparse + +import torch_geometric as pyg +from torch_geometric.graphgym.register import ( + register_edge_encoder, + register_node_encoder, +) + +from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_scatter import scatter +import warnings + +def full_edge_index(edge_index, batch=None): + """ + Retunr the Full batched sparse adjacency matrices given by edge indices. + Returns batched sparse adjacency matrices with exactly those edges that + are not in the input `edge_index` while ignoring self-loops. + Implementation inspired by `torch_geometric.utils.to_dense_adj` + Args: + edge_index: The edge indices. + batch: Batch vector, which assigns each node to a specific example. + Returns: + Complementary edge index. + """ + + if batch is None: + batch = edge_index.new_zeros(edge_index.max().item() + 1) + + batch_size = batch.max().item() + 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter(one, batch, + dim=0, dim_size=batch_size, reduce='add') + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + negative_index_list = [] + for i in range(batch_size): + n = num_nodes[i].item() + size = [n, n] + adj = torch.ones(size, dtype=torch.short, + device=edge_index.device) + + adj = adj.view(size) + _edge_index = adj.nonzero(as_tuple=False).t().contiguous() + # _edge_index, _ = remove_self_loops(_edge_index) + negative_index_list.append(_edge_index + cum_nodes[i]) + + edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() + return edge_index_full + + + +class RRWPLinearNodeEncoder(torch.nn.Module): + """ + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn + """ + def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): + super().__init__() + self.batchnorm = batchnorm + self.layernorm = layernorm + self.name = pe_name + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + # Encode just the first dimension if more exist + rrwp = batch[f"{self.name}"] + rrwp = self.fc(rrwp) + + if self.batchnorm: + rrwp = self.bn(rrwp) + + if self.layernorm: + rrwp = self.ln(rrwp) + + if "x" in batch: + batch.x = batch.x + rrwp + else: + batch.x = rrwp + + return batch + + +class RRWPLinearEdgeEncoder(torch.nn.Module): + ''' + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder 
in same cases + - Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + ''' + def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, + pad_to_full_graph=True, fill_value=0., + add_node_attr_as_self_loop=False, + overwrite_old_attr=False): + super().__init__() + # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + + self.batchnorm = batchnorm + self.layernorm = layernorm + if self.batchnorm or self.layernorm: + warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0. + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + rrwp_idx = batch.rrwp_index + rrwp_val = batch.rrwp_val + edge_index = batch.edge_index + edge_attr = batch.edge_attr + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) + edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
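+            # coalesce with op="add" sums duplicate (row, col) entries, so the
+            # projected RRWP values are merged into any existing edge attributes
+            # at the same positions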
+ + #print('-->>>>', edge_attr.size(), rrwp_val.size()) + out_idx, out_val = torch_sparse.coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + batch.num_nodes, batch.num_nodes, + op="add" + ) + + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=batch.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = torch_sparse.coalesce( + out_idx, out_val, batch.num_nodes, batch.num_nodes, + op="add" + ) + + if self.batchnorm: + out_val = self.bn(out_val) + + if self.layernorm: + out_val = self.ln(out_val) + + + batch.edge_index, batch.edge_attr = out_idx, out_val + return batch + + def __repr__(self): + return f"{self.__class__.__name__}" \ + f"(pad_to_full_graph={self.pad_to_full_graph}," \ + f"fill_value={self.fill_value}," \ + f"{self.fc.__repr__()})" + + + From a67e52285b69d5fae828b5c6231a586daa6d629e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:36 -0500 Subject: [PATCH 04/62] clean up imported layers and encoders Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 6 +-- gridfm_graphkit/models/grit_layer.py | 3 -- gridfm_graphkit/models/grit_transformer.py | 50 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 14 ++---- 4 files changed, 29 insertions(+), 44 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 904d6dc..dc6f3a1 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -35,8 +35,7 @@ model: num_layers: 10 output_dim: 6 pe_dim: 20 - type: GRIT #GPSTransformer # - layers_pre_mp: 0 + type: GRIT act: relu encoder: node_encoder: True @@ -45,10 +44,7 @@ model: node_encoder_bn: True gt: layer_type: GritTransformer - # layers: 10 - # n_heads: 8 dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` - # dropout: 0.0 layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index dc6cf97..b477980 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -6,8 +6,6 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_scatter import scatter, scatter_max, scatter_add -from grit.utils import negate_edge_index -from torch_geometric.graphgym.register import * import opt_einsum as oe from yacs.config import CfgNode as CN @@ -141,7 +139,6 @@ def forward(self, batch): return h_out, e_out -@register_layer("GritTransformer") class GritTransformerLayer(nn.Module): """ Proposed Transformer Layer for GRIT diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e3c6047..715c25f 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,15 +1,29 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY -from torch import nn import torch - +from torch import nn from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from grit_layer import GritTransformerLayer -# TODO verify use -from torch_geometric.graphgym.models.gnn import GNNPreMP -from torch_geometric.graphgym.models.layer import (new_layer_config, - BatchNorm1dNode) -from 
torch_geometric.graphgym.models.layer import new_layer_config, MLP + + +class BatchNorm1dNode(torch.nn.Module): + r"""A batch normalization layer for node-level features. + + Args: + dim_in (int): BatchNorm input dimension. + TODO fill in comments + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.x = self.bn(batch.x) + return batch class LinearNodeEncoder(torch.nn.Module): @@ -59,23 +73,16 @@ def __init__( # Encode integer node features via nn.Embeddings self.node_encoder = LinearNodeEncoder(dim_inner) if args.node_encoder_bn: - self.node_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_inner, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.edge_encoder: - # Hard-limit max edge dim for PNA. - if 'PNA' in args.model.gt.layer_type: # TODO remove condition if PNA not needed - dim_edge = min(128, dim_inner) - else: - dim_edge = dim_inner + + dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(dim_edge) if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode( - new_layer_config(dim_edge, -1, -1, has_act=False, - has_bias=False, cfg=cfg)) + self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -107,7 +114,6 @@ def __init__(self, args): ) # TODO add args dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( @@ -123,12 +129,6 @@ def __init__(self, args): fill_value=0. ) - - if args.model.layers_pre_mp > 0: - self.pre_mp = GNNPreMP( - dim_in, dim_inner, args.model.layers_pre_mp) - dim_in = dim_inner - assert args.model.hidden_size == dim_inner == dim_in, \ "The inner and hidden dims must match." diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index f98118e..b73e463 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -1,25 +1,18 @@ -''' +""" The RRWP encoder for GRIT (ours) -''' +""" import torch from torch import nn from torch.nn import functional as F -from ogb.utils.features import get_bond_feature_dims import torch_sparse -import torch_geometric as pyg -from torch_geometric.graphgym.register import ( - register_edge_encoder, - register_node_encoder, -) - from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops from torch_scatter import scatter import warnings def full_edge_index(edge_index, batch=None): """ - Retunr the Full batched sparse adjacency matrices given by edge indices. + Return the Full batched sparse adjacency matrices given by edge indices. Returns batched sparse adjacency matrices with exactly those edges that are not in the input `edge_index` while ignoring self-loops. Implementation inspired by `torch_geometric.utils.to_dense_adj` @@ -152,7 +145,6 @@ def forward(self, batch): # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
- #print('-->>>>', edge_attr.size(), rrwp_val.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From 6966f5ffc56a2384a2cb730324aa1bbf8269efb7 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:37 -0500 Subject: [PATCH 05/62] flow in basic structure for RRWP calculation Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 13 +- gridfm_graphkit/datasets/posenc_stats.py | 423 ++++++++++++++++++ .../datasets/powergrid_datamodule.py | 14 + gridfm_graphkit/datasets/rrwp.py | 103 +++++ 4 files changed, 547 insertions(+), 6 deletions(-) create mode 100644 gridfm_graphkit/datasets/posenc_stats.py create mode 100644 gridfm_graphkit/datasets/rrwp.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index dc6f3a1..05f31be 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -26,6 +26,12 @@ data: test_ratio: 0.1 val_ratio: 0.1 workers: 4 + posenc_RRWP: # TODO maybe better with data section... + enable: True + ksteps: 21 + add_identity: True + add_node_attr: False + add_inverse: False model: attention_head: 8 dropout: 0.1 @@ -57,12 +63,7 @@ model: O_e: True norm_e: True signed_sqrt: True - posenc_RRWP: - enable: True - ksteps: 21 - add_identity: True - add_node_attr: False - add_inverse: False + optimizer: beta1: 0.9 beta2: 0.999 diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 0000000..492a0a6 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,423 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, + to_undirected, to_dense_adj) +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter_add +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp + + +def compute_posenc_stats(data, pe_types, is_undirected, cfg): + """Precompute positional encodings for the given graph. + Supported PE statistics to precompute, selected by `pe_types`: + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + This can also be a combination, e.g. 'eigen+rw_landing' + is_undirected: True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. + for t in pe_types: + if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', + 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + # Basic preprocessing of the input graph. + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
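+    # note: the LapPE eigen config is read unconditionally here, even when only RRWP is enabled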
+ laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower() + if laplacian_norm_type == 'none': + laplacian_norm_type = None + if is_undirected: + undir_edge_index = data.edge_index + else: + undir_edge_index = to_undirected(data.edge_index) + + # Eigen values and vectors. + evals, evects = None, None + if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types: + # Eigen-decomposition with numpy, can be reused for Heat kernels. + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=laplacian_norm_type, + num_nodes=N) + ) + evals, evects = np.linalg.eigh(L.toarray()) + + if 'LapPE' in pe_types: + max_freqs=cfg.posenc_LapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm + elif 'EquivStableLapPE' in pe_types: + max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs + eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm + + data.EigVals, data.EigVecs = get_lap_decomp_stats( + evals=evals, evects=evects, + max_freqs=max_freqs, + eigvec_norm=eigvec_norm) + + if 'SignNet' in pe_types: + # Eigen-decomposition with numpy for SignNet. + norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() + if norm_type == 'none': + norm_type = None + L = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=norm_type, + num_nodes=N) + ) + evals_sn, evects_sn = np.linalg.eigh(L.toarray()) + data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( + evals=evals_sn, evects=evects_sn, + max_freqs=cfg.posenc_SignNet.eigen.max_freqs, + eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) + + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, + edge_index=data.edge_index, + num_nodes=N) + data.pestat_RWSE = rw_landing + + # Heat Kernels. + if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: + # Get the eigenvalues and eigenvectors of the regular Laplacian, + # if they have not yet been computed for 'eigen'. + if laplacian_norm_type is not None or evals is None or evects is None: + L_heat = to_scipy_sparse_matrix( + *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) + ) + evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) + else: + evals_heat, evects_heat = evals, evects + evals_heat = torch.from_numpy(evals_heat) + evects_heat = torch.from_numpy(evects_heat) + + # Get the full heat kernels. + if 'HKfullPE' in pe_types: + # The heat kernels can't be stored in the Data object without + # additional padding because in PyG's collation of the graphs the + # sizes of tensors must match except in dimension 0. Do this when + # the full heat kernels are actually used downstream by an Encoder. + raise NotImplementedError() + # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, + # kernel_times=kernel_param.times) + # data.pestat_HKdiagSE = hk_diag + # Get heat kernel diagonals in more efficient way. + if 'HKdiagSE' in pe_types: + kernel_param = cfg.posenc_HKdiagSE.kernel + if len(kernel_param.times) == 0: + raise ValueError("Diffusion times are required for heat kernel") + hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, + kernel_times=kernel_param.times, + space_dim=0) + data.pestat_HKdiagSE = hk_diag + + # Electrostatic interaction inspired kernel. 
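+    # (per-node summary statistics of the Laplacian pseudoinverse "potentials")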
+ if 'ElstaticSE' in pe_types: + elstatic = get_electrostatic_function_encoding(undir_edge_index, N) + data.pestat_ElstaticSE = elstatic + + if 'RRWP' in pe_types: + param = cfg.posenc_RRWP + transform = partial(add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True, + spd=param.spd, # by default False + ) + data = transform(data) + + return data + + +def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): + """Compute Laplacian eigen-decomposition-based PE stats of the given graph. + + Args: + evals, evects: Precomputed eigen-decomposition + max_freqs: Maximum number of top smallest frequencies / eigenvecs to use + eigvec_norm: Normalization for the eigen vectors of the Laplacian + Returns: + Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node + Tensor (num_nodes, max_freqs) of eigenvector values per node + """ + N = len(evals) # Number of nodes, including disconnected nodes. + + # Keep up to the maximum desired number of frequencies. + idx = evals.argsort()[:max_freqs] + evals, evects = evals[idx], np.real(evects[:, idx]) + evals = torch.from_numpy(np.real(evals)).clamp_min(0) + + # Normalize and pad eigen vectors. + evects = torch.from_numpy(evects).float() + evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) + if N < max_freqs: + EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) + else: + EigVecs = evects + + # Pad and save eigenvalues. + if N < max_freqs: + EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) + else: + EigVals = evals.unsqueeze(0) + EigVals = EigVals.repeat(N, 1).unsqueeze(2) + + return EigVals, EigVecs + + +def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, + num_nodes=None, space_dim=0): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, dest = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. + deg_inv = deg.pow(-1.) + deg_inv.masked_fill_(deg_inv == float('inf'), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. 
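+        # (one matrix_power call per k; the consecutive-range branch above reuses Pk instead)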
+ for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + + +def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): + """Compute Heat kernel diagonal. + + This is a continuous function that represents a Gaussian in the Euclidean + space, and is the solution to the diffusion equation. + The random-walk diagonal should converge to this. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the diffusion diagonal by a factor `t^(space_dim/2)`. In + euclidean space, this correction means that the height of the + gaussian stays constant across time, if `space_dim` is the dimension + of the euclidean space. + + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + heat_kernels_diag = [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels diagonal only for each time + eigvec_mul = evects ** 2 + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) + this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + + # Multiply by `t` to stabilize the values, since the gaussian height + # is proportional to `1/t` + heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) + heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) + + return heat_kernels_diag + + +def get_heat_kernels(evects, evals, kernel_times=[]): + """Compute full Heat diffusion kernels. + + Args: + evects: Eigenvectors of the Laplacian matrix + evals: Eigenvalues of the Laplacian matrix + kernel_times: Time for the diffusion. Analogous to the k-steps in random + walk. The time is equivalent to the variance of the kernel. + """ + heat_kernels, rw_landing = [], [] + if len(kernel_times) > 0: + evects = F.normalize(evects, p=2., dim=0) + + # Remove eigenvalues == 0 from the computation of the heat kernel + idx_remove = evals < 1e-8 + evals = evals[~idx_remove] + evects = evects[:, ~idx_remove] + + # Change the shapes for the computations + evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} + evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node + + # Compute the heat kernels for each time + eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) + for t in kernel_times: + # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) + heat_kernels.append( + torch.sum(torch.exp(-t * evals) * eigvec_mul, + dim=0, keepdim=False) + ) + + heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) + + # Take the diagonal of each heat kernel, + # i.e. 
the landing probability of each of the random walks + rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) + + return heat_kernels, rw_landing + + +def get_electrostatic_function_encoding(edge_index, num_nodes): + """Kernel based on the electrostatic interaction between nodes. + """ + L = to_scipy_sparse_matrix( + *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) + ).todense() + L = torch.as_tensor(L) + Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) + A = deepcopy(L).abs() + A.fill_diagonal_(0) + DinvA = Dinv.matmul(A) + + electrostatic = torch.pinverse(L) + electrostatic = electrostatic - electrostatic.diag() + green_encoding = torch.stack([ + electrostatic.min(dim=0)[0], # Min of Vi -> j + electrostatic.max(dim=0)[0], # Max of Vi -> j + electrostatic.mean(dim=0), # Mean of Vi -> j + electrostatic.std(dim=0), # Std of Vi -> j + electrostatic.min(dim=1)[0], # Min of Vj -> i + electrostatic.max(dim=0)[0], # Max of Vj -> i + electrostatic.mean(dim=1), # Mean of Vj -> i + electrostatic.std(dim=1), # Std of Vj -> i + (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour + (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour + ], dim=1) + + return green_encoding + + +def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): + """ + Implement different eigenvector normalizations. + """ + + EigVals = EigVals.unsqueeze(0) + + if normalization == "L1": + # L1 normalization: eigvec / sum(abs(eigvec)) + denom = EigVecs.norm(p=1, dim=0, keepdim=True) + + elif normalization == "L2": + # L2 normalization: eigvec / sqrt(sum(eigvec^2)) + denom = EigVecs.norm(p=2, dim=0, keepdim=True) + + elif normalization == "abs-max": + # AbsMax normalization: eigvec / max|eigvec| + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + + elif normalization == "wavelength": + # AbsMax normalization, followed by wavelength multiplication: + # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) + denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom * 2 / np.pi + + elif normalization == "wavelength-asin": + # AbsMax normalization, followed by arcsin and wavelength multiplication: + # arcsin(eigvec / max|eigvec|) / sqrt(eigval) + denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) + EigVecs = torch.asin(EigVecs / denom_temp) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = eigval_denom + + elif normalization == "wavelength-soft": + # AbsSoftmax normalization, followed by wavelength multiplication: + # eigvec / (softmax|eigvec| * sqrt(eigval)) + denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) + eigval_denom = torch.sqrt(EigVals) + eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 + denom = denom * eigval_denom + + else: + raise ValueError(f"Unsupported normalization `{normalization}`") + + denom = denom.clamp_min(eps).expand_as(EigVecs) + EigVecs = EigVecs / denom + + return EigVecs + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data, HeteroData + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, is_undirected, cfg): + self.pe_types = pe_types + self.is_undirected = is_undirected + self.cfg = cfg + + def __call__(self, data: Data) -> Data: 
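+        # Example usage, mirroring the datamodule wiring in the next diff
+        # (assumes a PyG `Data` object and the cfg layout used above):
+        #   t = ComputePosencStat(pe_types=['RRWP'], is_undirected=True, cfg=cfg)
+        #   data = t(data)  # attaches data.rrwp, data.rrwp_index, data.rrwp_val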
+ data = compute_posenc_stats(data, pe_types=self.pe_types, + is_undirected=self.is_undirected, + cfg=self.cfg + ) + return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index c18c360..ad68f4f 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -10,6 +10,9 @@ ) from gridfm_graphkit.datasets.utils import split_dataset from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat + import numpy as np import random import warnings @@ -129,6 +132,17 @@ def setup(self, stage: str): mask_dim=self.args.data.mask_dim, transform=get_transform(args=self.args), ) + + if self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments + is_undirected=is_undirected, + cfg=cfg + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py new file mode 100644 index 0000000..d88e3e7 --- /dev/null +++ b/gridfm_graphkit/datasets/rrwp.py @@ -0,0 +1,103 @@ +# ------------------------ : new rwpse ---------------- +from typing import Union, Any, Optional +import numpy as np +import torch +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.data import Data, HeteroData +from torch_geometric.transforms import BaseTransform +from torch_scatter import scatter, scatter_add, scatter_max + +from torch_geometric.graphgym.config import cfg + +from torch_geometric.utils import ( + get_laplacian, + get_self_loop_attr, + to_scipy_sparse_matrix, +) +import torch_sparse +from torch_sparse import SparseTensor + + +def add_node_attr(data: Data, value: Any, + attr_name: Optional[str] = None) -> Data: + if attr_name is None: + if 'x' in data: + x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x + data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1) + else: + data.x = value + else: + data[attr_name] = value + + return data + + + +@torch.no_grad() +def add_full_rrwp(data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs + ): + device=data.edge_index.device + ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float('inf')] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) + rel_pe_idx = torch.stack([rel_pe_col, 
rel_pe_row], dim=0)
+    # GRIT performs a right-multiplication while adj is row-normalized,
+    # so the order of row and col needs to be switched.
+    # note: both orderings can work, but the current version is more reasonable.
+
+
+    if spd:
+        spd_idx = walk_length - torch.arange(walk_length)
+        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
+        val = torch.argmax(val, dim=-1)
+        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
+        abs_pe = torch.zeros_like(abs_pe)
+
+    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
+    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
+    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
+    data.log_deg = torch.log(deg + 1)
+    data.deg = deg.type(torch.long)
+
+    return data
+

From a7bd51d1747ae3dd6f12a712b2962484abc59d6f Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:38 -0500
Subject: [PATCH 06/62] clean up

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 gridfm_graphkit/datasets/posenc_stats.py      | 372 +-----------------
 .../datasets/powergrid_datamodule.py          |   5 +-
 2 files changed, 9 insertions(+), 368 deletions(-)

diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py
index 492a0a6..049633c 100644
--- a/gridfm_graphkit/datasets/posenc_stats.py
+++ b/gridfm_graphkit/datasets/posenc_stats.py
@@ -11,8 +11,10 @@
 from functools import partial
 
 from gridfm_graphkit.datasets.rrwp import add_full_rrwp
+from torch_geometric.transforms import BaseTransform
+from torch_geometric.data import Data
 
-def compute_posenc_stats(data, pe_types, is_undirected, cfg):
+def compute_posenc_stats(data, pe_types, cfg):
     """Precompute positional encodings for the given graph.
 
     Supported PE statistics to precompute, selected by `pe_types`:
     'LapPE': Laplacian eigen-decomposition.
     'RWSE': Random walk landing probabilities (diagonals of RW matrices).
     'HKfullPE': Full heat kernels and their diagonals.
                 (NOT IMPLEMENTED)
@@ -37,387 +39,27 @@
                      'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']:
             raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}")
 
-    # Basic preprocessing of the input graph.
-    if hasattr(data, 'num_nodes'):
-        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
-    else:
-        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
-    laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower()
-    if laplacian_norm_type == 'none':
-        laplacian_norm_type = None
-    if is_undirected:
-        undir_edge_index = data.edge_index
-    else:
-        undir_edge_index = to_undirected(data.edge_index)
-
-    # Eigen values and vectors.
-    evals, evects = None, None
-    if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types:
-        # Eigen-decomposition with numpy, can be reused for Heat kernels.
-        L = to_scipy_sparse_matrix(
-            *get_laplacian(undir_edge_index, normalization=laplacian_norm_type,
-                           num_nodes=N)
-        )
-        evals, evects = np.linalg.eigh(L.toarray())
-
-        if 'LapPE' in pe_types:
-            max_freqs=cfg.posenc_LapPE.eigen.max_freqs
-            eigvec_norm=cfg.posenc_LapPE.eigen.eigvec_norm
-        elif 'EquivStableLapPE' in pe_types:
-            max_freqs=cfg.posenc_EquivStableLapPE.eigen.max_freqs
-            eigvec_norm=cfg.posenc_EquivStableLapPE.eigen.eigvec_norm
-
-        data.EigVals, data.EigVecs = get_lap_decomp_stats(
-            evals=evals, evects=evects,
-            max_freqs=max_freqs,
-            eigvec_norm=eigvec_norm)
-
-    if 'SignNet' in pe_types:
-        # Eigen-decomposition with numpy for SignNet.
- norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower() - if norm_type == 'none': - norm_type = None - L = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=norm_type, - num_nodes=N) - ) - evals_sn, evects_sn = np.linalg.eigh(L.toarray()) - data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats( - evals=evals_sn, evects=evects_sn, - max_freqs=cfg.posenc_SignNet.eigen.max_freqs, - eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm) - - # Random Walks. - if 'RWSE' in pe_types: - kernel_param = cfg.posenc_RWSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("List of kernel times required for RWSE") - rw_landing = get_rw_landing_probs(ksteps=kernel_param.times, - edge_index=data.edge_index, - num_nodes=N) - data.pestat_RWSE = rw_landing - - # Heat Kernels. - if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types: - # Get the eigenvalues and eigenvectors of the regular Laplacian, - # if they have not yet been computed for 'eigen'. - if laplacian_norm_type is not None or evals is None or evects is None: - L_heat = to_scipy_sparse_matrix( - *get_laplacian(undir_edge_index, normalization=None, num_nodes=N) - ) - evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray()) - else: - evals_heat, evects_heat = evals, evects - evals_heat = torch.from_numpy(evals_heat) - evects_heat = torch.from_numpy(evects_heat) - - # Get the full heat kernels. - if 'HKfullPE' in pe_types: - # The heat kernels can't be stored in the Data object without - # additional padding because in PyG's collation of the graphs the - # sizes of tensors must match except in dimension 0. Do this when - # the full heat kernels are actually used downstream by an Encoder. - raise NotImplementedError() - # heat_kernels, hk_diag = get_heat_kernels(evects_heat, evals_heat, - # kernel_times=kernel_param.times) - # data.pestat_HKdiagSE = hk_diag - # Get heat kernel diagonals in more efficient way. - if 'HKdiagSE' in pe_types: - kernel_param = cfg.posenc_HKdiagSE.kernel - if len(kernel_param.times) == 0: - raise ValueError("Diffusion times are required for heat kernel") - hk_diag = get_heat_kernels_diag(evects_heat, evals_heat, - kernel_times=kernel_param.times, - space_dim=0) - data.pestat_HKdiagSE = hk_diag - - # Electrostatic interaction inspired kernel. - if 'ElstaticSE' in pe_types: - elstatic = get_electrostatic_function_encoding(undir_edge_index, N) - data.pestat_ElstaticSE = elstatic - if 'RRWP' in pe_types: param = cfg.posenc_RRWP transform = partial(add_full_rrwp, walk_length=param.ksteps, attr_name_abs="rrwp", attr_name_rel="rrwp", - add_identity=True, - spd=param.spd, # by default False + add_identity=True ) data = transform(data) return data -def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'): - """Compute Laplacian eigen-decomposition-based PE stats of the given graph. - - Args: - evals, evects: Precomputed eigen-decomposition - max_freqs: Maximum number of top smallest frequencies / eigenvecs to use - eigvec_norm: Normalization for the eigen vectors of the Laplacian - Returns: - Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node - Tensor (num_nodes, max_freqs) of eigenvector values per node - """ - N = len(evals) # Number of nodes, including disconnected nodes. - - # Keep up to the maximum desired number of frequencies. - idx = evals.argsort()[:max_freqs] - evals, evects = evals[idx], np.real(evects[:, idx]) - evals = torch.from_numpy(np.real(evals)).clamp_min(0) - - # Normalize and pad eigen vectors. 
- evects = torch.from_numpy(evects).float() - evects = eigvec_normalizer(evects, evals, normalization=eigvec_norm) - if N < max_freqs: - EigVecs = F.pad(evects, (0, max_freqs - N), value=float('nan')) - else: - EigVecs = evects - - # Pad and save eigenvalues. - if N < max_freqs: - EigVals = F.pad(evals, (0, max_freqs - N), value=float('nan')).unsqueeze(0) - else: - EigVals = evals.unsqueeze(0) - EigVals = EigVals.repeat(N, 1).unsqueeze(2) - - return EigVals, EigVecs - - -def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, - num_nodes=None, space_dim=0): - """Compute Random Walk landing probabilities for given list of K steps. - - Args: - ksteps: List of k-steps for which to compute the RW landings - edge_index: PyG sparse representation of the graph - edge_weight: (optional) Edge weights - num_nodes: (optional) Number of nodes in the graph - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the random-walk diagonal by a factor `k^(space_dim/2)`. - In euclidean space, this correction means that the height of - the gaussian distribution stays almost constant across the number of - steps, if `space_dim` is the dimension of the euclidean space. - - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - if edge_weight is None: - edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) - num_nodes = maybe_num_nodes(edge_index, num_nodes) - source, dest = edge_index[0], edge_index[1] - deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. - deg_inv = deg.pow(-1.) - deg_inv.masked_fill_(deg_inv == float('inf'), 0) - - if edge_index.numel() == 0: - P = edge_index.new_zeros((1, num_nodes, num_nodes)) - else: - # P = D^-1 * A - P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) - rws = [] - if ksteps == list(range(min(ksteps), max(ksteps) + 1)): - # Efficient way if ksteps are a consecutive sequence (most of the time the case) - Pk = P.clone().detach().matrix_power(min(ksteps)) - for k in range(min(ksteps), max(ksteps) + 1): - rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - Pk = Pk @ P - else: - # Explicitly raising P to power k for each k \in ksteps. - for k in ksteps: - rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ - (k ** (space_dim / 2))) - rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) - - return rw_landing - - -def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0): - """Compute Heat kernel diagonal. - - This is a continuous function that represents a Gaussian in the Euclidean - space, and is the solution to the diffusion equation. - The random-walk diagonal should converge to this. - - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - space_dim: (optional) Estimated dimensionality of the space. Used to - correct the diffusion diagonal by a factor `t^(space_dim/2)`. In - euclidean space, this correction means that the height of the - gaussian stays constant across time, if `space_dim` is the dimension - of the euclidean space. 
- - Returns: - 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs - """ - heat_kernels_diag = [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels diagonal only for each time - eigvec_mul = evects ** 2 - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j} * phi_{i, j}) - this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - - # Multiply by `t` to stabilize the values, since the gaussian height - # is proportional to `1/t` - heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2))) - heat_kernels_diag = torch.stack(heat_kernels_diag, dim=0).transpose(0, 1) - - return heat_kernels_diag - - -def get_heat_kernels(evects, evals, kernel_times=[]): - """Compute full Heat diffusion kernels. - - Args: - evects: Eigenvectors of the Laplacian matrix - evals: Eigenvalues of the Laplacian matrix - kernel_times: Time for the diffusion. Analogous to the k-steps in random - walk. The time is equivalent to the variance of the kernel. - """ - heat_kernels, rw_landing = [], [] - if len(kernel_times) > 0: - evects = F.normalize(evects, p=2., dim=0) - - # Remove eigenvalues == 0 from the computation of the heat kernel - idx_remove = evals < 1e-8 - evals = evals[~idx_remove] - evects = evects[:, ~idx_remove] - - # Change the shapes for the computations - evals = evals.unsqueeze(-1).unsqueeze(-1) # lambda_{i, ..., ...} - evects = evects.transpose(0, 1) # phi_{i,j}: i-th eigvec X j-th node - - # Compute the heat kernels for each time - eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1)) # (phi_{i, j1, ...} * phi_{i, ..., j2}) - for t in kernel_times: - # sum_{i>0}(exp(-2 t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2}) - heat_kernels.append( - torch.sum(torch.exp(-t * evals) * eigvec_mul, - dim=0, keepdim=False) - ) - - heat_kernels = torch.stack(heat_kernels, dim=0) # (Num kernel times) x (Num nodes) x (Num nodes) - - # Take the diagonal of each heat kernel, - # i.e. the landing probability of each of the random walks - rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1) # (Num nodes) x (Num kernel times) - - return heat_kernels, rw_landing - - -def get_electrostatic_function_encoding(edge_index, num_nodes): - """Kernel based on the electrostatic interaction between nodes. 
- """ - L = to_scipy_sparse_matrix( - *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes) - ).todense() - L = torch.as_tensor(L) - Dinv = torch.eye(L.shape[0]) * (L.diag() ** -1) - A = deepcopy(L).abs() - A.fill_diagonal_(0) - DinvA = Dinv.matmul(A) - - electrostatic = torch.pinverse(L) - electrostatic = electrostatic - electrostatic.diag() - green_encoding = torch.stack([ - electrostatic.min(dim=0)[0], # Min of Vi -> j - electrostatic.max(dim=0)[0], # Max of Vi -> j - electrostatic.mean(dim=0), # Mean of Vi -> j - electrostatic.std(dim=0), # Std of Vi -> j - electrostatic.min(dim=1)[0], # Min of Vj -> i - electrostatic.max(dim=0)[0], # Max of Vj -> i - electrostatic.mean(dim=1), # Mean of Vj -> i - electrostatic.std(dim=1), # Std of Vj -> i - (DinvA * electrostatic).sum(dim=0), # Mean of interaction on direct neighbour - (DinvA * electrostatic).sum(dim=1), # Mean of interaction from direct neighbour - ], dim=1) - - return green_encoding - - -def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12): - """ - Implement different eigenvector normalizations. - """ - - EigVals = EigVals.unsqueeze(0) - - if normalization == "L1": - # L1 normalization: eigvec / sum(abs(eigvec)) - denom = EigVecs.norm(p=1, dim=0, keepdim=True) - - elif normalization == "L2": - # L2 normalization: eigvec / sqrt(sum(eigvec^2)) - denom = EigVecs.norm(p=2, dim=0, keepdim=True) - - elif normalization == "abs-max": - # AbsMax normalization: eigvec / max|eigvec| - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - - elif normalization == "wavelength": - # AbsMax normalization, followed by wavelength multiplication: - # eigvec * pi / (2 * max|eigvec| * sqrt(eigval)) - denom = torch.max(EigVecs.abs(), dim=0, keepdim=True).values - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom * 2 / np.pi - - elif normalization == "wavelength-asin": - # AbsMax normalization, followed by arcsin and wavelength multiplication: - # arcsin(eigvec / max|eigvec|) / sqrt(eigval) - denom_temp = torch.max(EigVecs.abs(), dim=0, keepdim=True).values.clamp_min(eps).expand_as(EigVecs) - EigVecs = torch.asin(EigVecs / denom_temp) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = eigval_denom - - elif normalization == "wavelength-soft": - # AbsSoftmax normalization, followed by wavelength multiplication: - # eigvec / (softmax|eigvec| * sqrt(eigval)) - denom = (F.softmax(EigVecs.abs(), dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True) - eigval_denom = torch.sqrt(EigVals) - eigval_denom[EigVals < eps] = 1 # Problem with eigval = 0 - denom = denom * eigval_denom - - else: - raise ValueError(f"Unsupported normalization `{normalization}`") - - denom = denom.clamp_min(eps).expand_as(EigVecs) - EigVecs = EigVecs / denom - - return EigVecs - -from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data, HeteroData - class ComputePosencStat(BaseTransform): - def __init__(self, pe_types, is_undirected, cfg): + def __init__(self, pe_types, cfg): self.pe_types = pe_types - self.is_undirected = is_undirected self.cfg = cfg def __call__(self, data: Data) -> Data: - data = compute_posenc_stats(data, pe_types=self.pe_types, - is_undirected=self.is_undirected, + data = compute_posenc_stats(data, + pe_types=self.pe_types, cfg=self.cfg ) return data \ No newline at end of file diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py 
b/gridfm_graphkit/datasets/powergrid_datamodule.py index ad68f4f..d73f6fa 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -134,9 +134,8 @@ def setup(self, stage: str): ) if self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=pe_enabled_list, # TODO connect arguments - is_undirected=is_undirected, - cfg=cfg + pe_transform = ComputePosencStat(pe_types=['RRWP'], + cfg=self.args.data ) if dataset.transform is None: dataset.transform = pe_transform From 226f2a3098c4ed4499ac982d53f1e49338cce9fd Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 07/62] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 1 + gridfm_graphkit/datasets/posenc_stats.py | 3 +- .../datasets/powergrid_datamodule.py | 2 + gridfm_graphkit/datasets/rrwp.py | 1 - gridfm_graphkit/models/__init__.py | 3 +- gridfm_graphkit/models/grit_layer.py | 3 +- gridfm_graphkit/models/grit_transformer.py | 42 +++++++++---------- 7 files changed, 28 insertions(+), 27 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 05f31be..9896057 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,6 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True + .edge_encoder_bn: True gt: layer_type: GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 049633c..8bb2b9d 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -16,7 +16,8 @@ def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. - Supported PE statistics to precompute, selected by `pe_types`: + Supported PE statistics to precompute in original implementation, + selected by `pe_types`: 'LapPE': Laplacian eigen-decomposition. 'RWSE': Random walk landing probabilities (diagonals of RW matrices). 'HKfullPE': Full heat kernels and their diagonals. 
(NOT IMPLEMENTED) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index d73f6fa..9960e08 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -13,6 +13,8 @@ from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat +import torch_geometric.transforms as T + import numpy as np import random import warnings diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index d88e3e7..26218f0 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -8,7 +8,6 @@ from torch_geometric.transforms import BaseTransform from torch_scatter import scatter, scatter_add, scatter_max -from torch_geometric.graphgym.config import cfg from torch_geometric.utils import ( get_laplacian, diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index de355d3..cc66936 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,5 @@ from gridfm_graphkit.models.gps_transformer import GPSTransformer from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv +from gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv"] +__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index b477980..53e7217 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,7 +8,6 @@ import opt_einsum as oe -from yacs.config import CfgNode as CN import warnings @@ -48,7 +47,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, edge_enhance=True, sqrt_relu=False, signed_sqrt=True, - cfg=CN(), + cfg={}, **kwargs): super().__init__() diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 715c25f..49bfdf2 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,8 +1,9 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY import torch from torch import nn -from rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder -from grit_layer import GritTransformerLayer + +from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from gridfm_graphkit.models.grit_layer import GritTransformerLayer @@ -27,25 +28,21 @@ def forward(self, batch): class LinearNodeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, dim_in, emb_dim): super().__init__() - self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim) + self.encoder = torch.nn.Linear(dim_in, emb_dim) def forward(self, batch): batch.x = self.encoder(batch.x) return batch class LinearEdgeEncoder(torch.nn.Module): - def __init__(self, emb_dim): + def __init__(self, edge_dim, emb_dim): super().__init__() - if cfg.dataset.name in ['MNIST', 'CIFAR10']: - self.in_dim = 1 - elif cfg.dataset.name.startswith('attributed_triangle-'): - self.in_dim = 2 - else: - raise ValueError("Input edge feature dim is required to be hardset " - "or refactored to use a cfg option.") + + self.in_dim = edge_dim + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) def forward(self, batch): @@ -69,20 +66,20 @@ def __init__( ): super(FeatureEncoder, self).__init__() self.dim_in = dim_in - if args.node_encoder: + if args.encoder.node_encoder: # Encode integer node features via 
nn.Embeddings - self.node_encoder = LinearNodeEncoder(dim_inner) - if args.node_encoder_bn: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if args.encoder.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner - if args.edge_encoder: - - dim_edge = dim_inner + if args.encoder.edge_encoder: + args.edge_dim + enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings - self.edge_encoder = LinearEdgeEncoder(dim_edge) - if cfg.dataset.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode(dim_edge, 1e-5, 0.1) + self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) + if args.encoder.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): @@ -121,7 +118,7 @@ def __init__(self, args): dim_inner ) rel_pe_dim = args.model.posenc_RRWP.ksteps - self.rrwp_rel_encoder = RRWPLinearNodeEncoder( + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, pad_to_full_graph=args.model.gt.attn.full_attn, @@ -158,6 +155,7 @@ def __init__(self, args): ) def forward(self, batch): + print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) From 88d9ca6a7d2cea3f6f935b123b236a203452809f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:38 -0500 Subject: [PATCH 08/62] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 9896057..bd76278 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -48,7 +48,7 @@ model: edge_encoder: True node_encoder_name: TODO node_encoder_bn: True - .edge_encoder_bn: True + edge_encoder_bn: True gt: layer_type: GritTransformer dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 49bfdf2..e8746b8 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -74,7 +74,7 @@ def __init__( # Update dim_in to reflect the new dimension fo the node features self.dim_in = dim_inner if args.encoder.edge_encoder: - args.edge_dim + edge_dim = args.edge_dim enc_dim_edge = dim_inner # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) @@ -107,8 +107,8 @@ def __init__(self, args): self.encoder = FeatureEncoder( dim_in, dim_inner, - args.model.encoder - ) # TODO add args + args.model + ) dim_in = self.encoder.dim_in if args.model.posenc_RRWP.enable: From b7d9dcf035026c797c1d68577d8b04ed9d3247ee Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 09/62] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 ++ gridfm_graphkit/models/grit_layer.py | 4 ++-- gridfm_graphkit/models/grit_transformer.py | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/config/grit_pretraining.yaml 
b/examples/config/grit_pretraining.yaml index bd76278..73e5817 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -64,6 +64,8 @@ model: O_e: True norm_e: True signed_sqrt: True + bn_momentum: 0.1 + bn_no_runner: False optimizer: beta1: 0.9 diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 53e7217..ffcf584 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = cfg.get("update_e", True) + self.update_e = getattr(cfg, "update_e", True) self.bn_momentum = cfg.bn_momentum self.bn_no_runner = cfg.bn_no_runner - self.rezero = cfg.get("rezero", False) + self.rezero = getattr(cfg, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e8746b8..8d4f696 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -111,13 +111,13 @@ def __init__(self, args): ) dim_in = self.encoder.dim_in - if args.model.posenc_RRWP.enable: + if args.data.posenc_RRWP.enable: self.rrwp_abs_encoder = RRWPLinearNodeEncoder( - args.model.posenc_RRWP.ksteps, + args.data.posenc_RRWP.ksteps, dim_inner ) - rel_pe_dim = args.model.posenc_RRWP.ksteps + rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, dim_edge, From 38cc44a31d909cba09a8296784fc5f02e9637dba Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 10/62] matching up parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index ffcf584..98d0b6c 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -166,10 +166,10 @@ def __init__(self, in_dim, out_dim, num_heads, self.batch_norm = batch_norm # ------- - self.update_e = getattr(cfg, "update_e", True) - self.bn_momentum = cfg.bn_momentum - self.bn_no_runner = cfg.bn_no_runner - self.rezero = getattr(cfg, "rezero", False) + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) self.act = act_dict[act]() if act is not None else nn.Identity() if cfg.get("attn", None) is None: From 7fded9556996acec2d55919e95059fe4faba93fa Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 11/62] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 98d0b6c..9723304 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -72,7 +72,7 @@ def __init__(self, in_dim, out_dim, num_heads, use_bias, if act is None: self.act = nn.Identity() else: - self.act = 
act_dict[act]() + self.act = nn.ReLU() if self.edge_enhance: self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) @@ -171,12 +171,15 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - self.act = act_dict[act]() if act is not None else nn.Identity() - if cfg.get("attn", None) is None: + if act is not None + self.act = nn.ReLU() + else: + self.act = nn.Identity() + + if getattr(cfg, "attn", None) is None: cfg.attn = dict() - self.use_attn = cfg.attn.get("use", True) - # self.sigmoid_deg = cfg.attn.get("sigmoid_deg", False) - self.deg_scaler = cfg.attn.get("deg_scaler", True) + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) self.attention = MultiHeadAttentionLayerGritSparse( in_dim=in_dim, From 0f3b803a2a2ccfa69b11bac643b1ad14e2e24d38 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:39 -0500 Subject: [PATCH 12/62] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 9723304..0bcdf73 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -171,7 +171,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.bn_no_runner = cfg.attn.bn_no_runner self.rezero = getattr(cfg.attn, "rezero", False) - if act is not None + if act is not None: self.act = nn.ReLU() else: self.act = nn.Identity() From af8ad03ad589f781a35d5deaa557a616c5463e80 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 13/62] matching up parameters in grit layer Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_layer.py | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index 0bcdf73..f95ffc7 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -185,31 +185,31 @@ def __init__(self, in_dim, out_dim, num_heads, in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, - use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), - edge_enhance=cfg.attn.get("edge_enhance", True), - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn,"scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) - if cfg.attn.get('graphormer_attn', False): + if getattr(cfg.attn, 'graphormer_attn', False): self.attention = MultiHeadAttentionLayerGraphormerSparse( in_dim=in_dim, out_dim=out_dim // num_heads, num_heads=num_heads, - 
use_bias=cfg.attn.get("use_bias", False), + use_bias=getattr(cfg.attn, "use_bias", False), dropout=attn_dropout, - clamp=cfg.attn.get("clamp", 5.), - act=cfg.attn.get("act", "relu"), + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), edge_enhance=True, - sqrt_relu=cfg.attn.get("sqrt_relu", False), - signed_sqrt=cfg.attn.get("signed_sqrt", False), - scaled_attn =cfg.attn.get("scaled_attn", False), - no_qk=cfg.attn.get("no_qk", False), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), ) @@ -232,8 +232,8 @@ def __init__(self, in_dim, out_dim, num_heads, if self.batch_norm: # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) - self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) - self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity() + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() # FFN for h self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) @@ -243,7 +243,7 @@ def __init__(self, in_dim, out_dim, num_heads, self.layer_norm2_h = nn.LayerNorm(out_dim) if self.batch_norm: - self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if self.rezero: self.alpha1_h = nn.Parameter(torch.zeros(1,1)) From f430f2a7562c3db8b570b673422e7406446de593 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 14/62] matching up parameters in data module Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index 9960e08..e67956e 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -142,7 +142,7 @@ def setup(self, stage: str): if dataset.transform is None: dataset.transform = pe_transform else: - dataset.transform = T.compose([pe_transform, dataset.transform]) + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) From e1c489050bd9995e10efe69d6f356515dc6b965d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 15/62] flow over parameters from base model Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 8d4f696..10af25a 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ 
b/gridfm_graphkit/models/grit_transformer.py @@ -103,6 +103,19 @@ def __init__(self, args): num_heads = args.model.attention_head dropout = args.model.dropout num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) self.encoder = FeatureEncoder( dim_in, From 36dca0095cecea391a586b3ec205a229b3511e7c Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:40 -0500 Subject: [PATCH 16/62] verified encodings and data flow to model forward method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- .../tasks/feature_reconstruction_task.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 10af25a..cf07a8c 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -167,7 +167,7 @@ def __init__(self, args): nn.Linear(dim_inner, dim_out), ) - def forward(self, batch): + def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index cb6963b..da2f478 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -74,11 +74,11 @@ def __init__(self, args, node_normalizers, edge_normalizers): self.edge_normalizers = edge_normalizers self.save_hyperparameters() - def forward(self, x, pe, edge_index, edge_attr, batch, mask=None): - if mask is not None: - mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1) - x[:, : mask.shape[1]][mask] = mask_value_expanded[mask] - return self.model(x, pe, edge_index, edge_attr, batch) + def forward(self, batch): + if batch.mask is not None: + mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) + batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] + return self.model(batch) @rank_zero_only def on_fit_start(self): @@ -111,12 +111,14 @@ def on_fit_start(self): def shared_step(self, batch): output = self.forward( - x=batch.x, - pe=batch.pe, - edge_index=batch.edge_index, - edge_attr=batch.edge_attr, - batch=batch.batch, - mask=batch.mask, + # TODO update args list in the GPS Transf. 
for consistency + # x=batch.x, + # pe=batch.pe, + # edge_index=batch.edge_index, + # edge_attr=batch.edge_attr, + # batch=batch.batch, + # mask=batch.mask, + batch ) loss_dict = self.loss_fn( From a8ec56efdbef89eef493f3e6f1b897f3e17140b6 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:41 -0500 Subject: [PATCH 17/62] match feature dimensions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 2 +- gridfm_graphkit/models/rrwp_encoder.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index cf07a8c..50d0fec 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -133,7 +133,7 @@ def __init__(self, args): rel_pe_dim = args.data.posenc_RRWP.ksteps self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( rel_pe_dim, - dim_edge, + dim_inner, pad_to_full_graph=args.model.gt.attn.full_attn, add_node_attr_as_self_loop=False, fill_value=0. diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index b73e463..33c5215 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,6 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph @@ -144,7 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
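+        # The two debug prints below compare sizes ahead of the coalesce that
+        # follows: torch.cat needs edge_index / rrwp_idx to both be shaped
+        # [2, num_entries] and edge_attr / rrwp_val to share one feature
+        # dimension before duplicate (row, col) entries are merged.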
-
+        print('xxxx', edge_attr.size(), rrwp_val.size())
+        print('yyyy', edge_index.size(), rrwp_idx.size())
         out_idx, out_val = torch_sparse.coalesce(
             torch.cat([edge_index, rrwp_idx], dim=1),
             torch.cat([edge_attr, rrwp_val], dim=0),

From 0868b96e7a5aed4e2648f6288d85073a260afd44 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:41 -0500
Subject: [PATCH 18/62] match feature dimensions

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 examples/config/grit_pretraining.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml
index 73e5817..52c44c8 100644
--- a/examples/config/grit_pretraining.yaml
+++ b/examples/config/grit_pretraining.yaml
@@ -36,7 +36,7 @@ model:
   attention_head: 8
   dropout: 0.1
   edge_dim: 2
-  hidden_size: 123
+  hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner`
   input_dim: 9
   num_layers: 10
   output_dim: 6

From 3cc21a38eea36b3dddca0e3a2d1817e15ac66b86 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:41 -0500
Subject: [PATCH 19/62] reformat decoder to handle batch format

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 gridfm_graphkit/cli.py                     |  5 ++
 gridfm_graphkit/models/grit_transformer.py | 58 ++++++++++++++++++++--
 2 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py
index a7507c1..79cb772 100644
--- a/gridfm_graphkit/cli.py
+++ b/gridfm_graphkit/cli.py
@@ -77,6 +77,11 @@ def main_cli(args):
         max_epochs=config_args.training.epochs,
         callbacks=get_training_callbacks(config_args),
     )
+
+    # print('******model*****')
+    # print(model)
+    # print('******model*****')
+
     if args.command == "train" or args.command == "finetune":
         trainer.fit(model=model, datamodule=litGrid)

diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py
index 50d0fec..2e85d4e 100644
--- a/gridfm_graphkit/models/grit_transformer.py
+++ b/gridfm_graphkit/models/grit_transformer.py
@@ -85,6 +85,49 @@ def forward(self, batch):
         for module in self.children():
             batch = module(batch)
         return batch
+
+class GraphHead(nn.Module):
+    """
+    Prediction head for graph prediction tasks.
+    Args:
+        dim_in (int): Input dimension.
+        dim_out (int): Output dimension. For binary prediction, dim_out=1.
+ """ + + def __init__(self, dim_in, dim_out): + super().__init__() + # self.deg_scaler = False + # self.fwl = False + + # list_FC_layers = [ + # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) + # for l in range(L)] + # list_FC_layers.append( + # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.LeakyReLU(), + nn.Linear(dim_in, dim_out), + ) #nn.ModuleList(list_FC_layers) + # self.L = L + # self.activation = register.act_dict[cfg.gnn.act]() + # note: modified to add () in the end from original code of 'GPS' + # potentially due to the change of PyG/GraphGym version + + def _apply_index(self, batch): + return batch.graph_feature, batch.y + + def forward(self, batch): + # graph_emb = self.pooling_fun(batch.x, batch.batch) + graph_emb = self.FC_layers(batch.x) + # for l in range(self.L): + # graph_emb = self.FC_layers[l](graph_emb) + # graph_emb = self.activation(graph_emb) + # graph_emb = self.FC_layers[self.L](graph_emb) + batch.graph_feature = graph_emb + pred, label = self._apply_index(batch) + return pred @MODELS_REGISTRY.register("GRIT") @@ -161,15 +204,20 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - self.decoder = nn.Sequential( - nn.Linear(dim_inner, dim_inner), - nn.LeakyReLU(), - nn.Linear(dim_inner, dim_out), - ) + # self.decoder = nn.Sequential( + # nn.Linear(dim_inner, dim_inner), + # nn.LeakyReLU(), + # nn.Linear(dim_inner, dim_out), + # ) + + self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters print('process--->>', batch) # TODO remove print for module in self.children(): + print('----------') + print(module) batch = module(batch) + print('--passed--') return batch \ No newline at end of file From 17830516f74d297db439b659a14edf8db99506f3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 20/62] confirmed training loop functions Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 8 ++++---- gridfm_graphkit/models/rrwp_encoder.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 2e85d4e..4e09de9 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -213,11 +213,11 @@ def __init__(self, args): self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - print('process--->>', batch) # TODO remove print + # print('process--->>', batch) # TODO remove print for module in self.children(): - print('----------') - print(module) + # print('----------') + # print(module) batch = module(batch) - print('--passed--') + # print('--passed--') return batch \ No newline at end of file diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 33c5215..270ca86 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -145,8 +145,8 @@ def forward(self, batch): else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
- print('xxxx', edge_attr.size(), rrwp_val.size()) - print('yyyy', edge_index.size(), rrwp_idx.size()) + # print('xxxx', edge_attr.size(), rrwp_val.size()) + # print('yyyy', edge_index.size(), rrwp_idx.size()) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), From c75012fa845f0775f47df55eda7f2fe2c85d25d9 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 21/62] update toml Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 22 +++++++++++----------- gridfm_graphkit/models/rrwp_encoder.py | 2 +- pyproject.toml | 1 + 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 52c44c8..30ee213 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -11,21 +11,21 @@ data: networks: # - Texas2k_case1_2016summerpeak - case24_ieee_rts - # - case118_ieee - # - case300_ieee + - case118_ieee + - case300_ieee - case89_pegase - # - case240_pserc + - case240_pserc normalization: baseMVAnorm scenarios: # - 5000 - - 5000 - - 5000 - # - 30000 - # - 50000 - # - 50000 + - 50000 + - 50000 + - 30000 + - 50000 + - 50000 test_ratio: 0.1 val_ratio: 0.1 - workers: 4 + workers: 8 posenc_RRWP: # TODO maybe better with data section... enable: True ksteps: 21 @@ -36,7 +36,7 @@ model: attention_head: 8 dropout: 0.1 edge_dim: 2 - hidden_size: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` input_dim: 9 num_layers: 10 output_dim: 6 @@ -51,7 +51,7 @@ model: edge_encoder_bn: True gt: layer_type: GritTransformer - dim_hidden: 64 # `gt.dim_hidden` must match `gnn.dim_inner` + dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` layer_norm: False batch_norm: True update_e: True diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py index 270ca86..2dadd35 100644 --- a/gridfm_graphkit/models/rrwp_encoder.py +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -114,7 +114,7 @@ def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias= if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") - print('--------fc in and out:', emb_dim, out_dim) + # print('--------fc in and out:', emb_dim, out_dim) self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph diff --git a/pyproject.toml b/pyproject.toml index 51c8665..2250ae9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "pyyaml", "lightning", "seaborn", + "opt-einsum", ] [project.optional-dependencies] From 3d3f98b3123defb965144e8d50868c46bfdab455 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:42 -0500 Subject: [PATCH 22/62] added forward method to transform class Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 6 +++++- gridfm_graphkit/models/grit_transformer.py | 4 ++-- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py 
index 8bb2b9d..21e7841 100644
--- a/gridfm_graphkit/datasets/posenc_stats.py
+++ b/gridfm_graphkit/datasets/posenc_stats.py
@@ -13,6 +13,7 @@
 from torch_geometric.transforms import BaseTransform
 from torch_geometric.data import Data
+from typing import Any
 def compute_posenc_stats(data, pe_types, cfg):
     """Precompute positional encodings for the given graph.
@@ -58,9 +59,12 @@ def __init__(self, pe_types, cfg):
         self.pe_types = pe_types
         self.cfg = cfg

+    def forward(self, data: Any) -> Any:
+        pass
+
     def __call__(self, data: Data) -> Data:
         data = compute_posenc_stats(data,
                                     pe_types=self.pe_types,
                                     cfg=self.cfg
                                     )
-        return data
\ No newline at end of file
+        return data
diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py
index 4e09de9..7caed0f 100644
--- a/gridfm_graphkit/models/grit_transformer.py
+++ b/gridfm_graphkit/models/grit_transformer.py
@@ -213,11 +213,11 @@ def __init__(self, args):
         self.decoder = GraphHead(dim_inner, dim_out)
     def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters
-        # print('process--->>', batch) # TODO remove print
+        #print('process--->>', batch) # TODO remove print
         for module in self.children():
             # print('----------')
             # print(module)
             batch = module(batch)
             # print('--passed--')
-        return batch
\ No newline at end of file
+        return batch
diff --git a/pyproject.toml b/pyproject.toml
index 2250ae9..4a17ed5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ classifiers = [
 dependencies = [
-    "torch>2.0",
+    "torch==2.6",
     "torch-geometric",
     "mlflow",
     "nbformat",

From d238e7591d4f358ac2877f838debd4ab64aa7afe Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:43 -0500
Subject: [PATCH 23/62] update readme with install instructions

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 78e8661..5a096dc 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ cd gridfm-graphkit
 python -m venv venv
 source venv/bin/activate
 pip install -e .
+pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
 ```
 For documentation generation and unit testing, install with the optional `dev` and `test` extras:

From 17b0889e339f0378d761afe047f5b08db0802d03 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 17 Nov 2025 14:07:43 -0500
Subject: [PATCH 24/62] verified compat with GPS and GNN

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 examples/config/grit_pretraining.yaml | 2 +-
 gridfm_graphkit/datasets/powergrid_datamodule.py | 2 +-
 gridfm_graphkit/models/gnn_transformer.py | 8 +++++++-
 gridfm_graphkit/models/gps_transformer.py | 8 +++++++-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml
index 30ee213..8f11c93 100644
--- a/examples/config/grit_pretraining.yaml
+++ b/examples/config/grit_pretraining.yaml
@@ -25,7 +25,7 @@ data:
   - 50000
   test_ratio: 0.1
   val_ratio: 0.1
-  workers: 8
+  workers: 0
   posenc_RRWP: # TODO maybe better with data section...
enable: True ksteps: 21 diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index e67956e..4b0320f 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -135,7 +135,7 @@ def setup(self, stage: str): transform=get_transform(args=self.args), ) - if self.args.data.posenc_RRWP.enable: + if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: pe_transform = ComputePosencStat(pe_types=['RRWP'], cfg=self.args.data ) diff --git a/gridfm_graphkit/models/gnn_transformer.py b/gridfm_graphkit/models/gnn_transformer.py index 9e1ab23..37d3632 100644 --- a/gridfm_graphkit/models/gnn_transformer.py +++ b/gridfm_graphkit/models/gnn_transformer.py @@ -74,7 +74,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -88,6 +88,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + for conv in self.layers: x = conv(x, edge_index, edge_attr) x = nn.LeakyReLU()(x) diff --git a/gridfm_graphkit/models/gps_transformer.py b/gridfm_graphkit/models/gps_transformer.py index cc8b648..ca45c5a 100644 --- a/gridfm_graphkit/models/gps_transformer.py +++ b/gridfm_graphkit/models/gps_transformer.py @@ -105,7 +105,7 @@ def __init__(self, args): requires_grad=False, ) - def forward(self, x, pe, edge_index, edge_attr, batch): + def forward(self, data_batch): """ Forward pass for the GPSTransformer. @@ -119,6 +119,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. 
""" + x=data_batch.x + pe=data_batch.pe + edge_index=data_batch.edge_index + edge_attr=data_batch.edge_attr + batch=data_batch.batch + x_pe = self.pe_norm(pe) x = self.encoder(x) From 091f0849ce997ff6d820b2102fb4d631faa7f789 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:43 -0500 Subject: [PATCH 25/62] work on comments and clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/rrwp.py | 19 +------ gridfm_graphkit/models/grit_layer.py | 12 ++-- gridfm_graphkit/models/grit_transformer.py | 55 +++++++------------ gridfm_graphkit/models/rrwp_encoder.py | 27 ++++----- .../tasks/feature_reconstruction_task.py | 7 --- 5 files changed, 42 insertions(+), 78 deletions(-) diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py index 26218f0..acbe112 100644 --- a/gridfm_graphkit/datasets/rrwp.py +++ b/gridfm_graphkit/datasets/rrwp.py @@ -1,20 +1,7 @@ -# ------------------------ : new rwpse ---------------- -from typing import Union, Any, Optional -import numpy as np +from typing import Any, Optional import torch import torch.nn.functional as F -import torch_geometric as pyg -from torch_geometric.data import Data, HeteroData -from torch_geometric.transforms import BaseTransform -from torch_scatter import scatter, scatter_add, scatter_max - - -from torch_geometric.utils import ( - get_laplacian, - get_self_loop_attr, - to_scipy_sparse_matrix, -) -import torch_sparse +from torch_geometric.data import Data from torch_sparse import SparseTensor @@ -42,8 +29,6 @@ def add_full_rrwp(data, spd=False, **kwargs ): - device=data.edge_index.device - ind_vec = torch.eye(walk_length, dtype=torch.float, device=device) num_nodes = data.num_nodes edge_index, edge_weight = data.edge_index, data.edge_weight diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py index f95ffc7..a1ffc4a 100644 --- a/gridfm_graphkit/models/grit_layer.py +++ b/gridfm_graphkit/models/grit_layer.py @@ -8,11 +8,12 @@ import opt_einsum as oe - import warnings + def pyg_softmax(src, index, num_nodes=None): - r"""Computes a sparsely evaluated softmax. + """ + Computes a sparsely evaluated softmax. Given a value tensor :attr:`src`, this function first groups the values along the first dimension based on the indices specified in :attr:`index`, and then proceeds to compute the softmax individually for each group. @@ -23,7 +24,8 @@ def pyg_softmax(src, index, num_nodes=None): num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) - :rtype: :class:`Tensor` + Returns: + out (Tensor) """ num_nodes = maybe_num_nodes(index, num_nodes) @@ -39,7 +41,7 @@ def pyg_softmax(src, index, num_nodes=None): class MultiHeadAttentionLayerGritSparse(nn.Module): """ - Proposed Attention Computation for GRIT + Attention Computation for GRIT """ def __init__(self, in_dim, out_dim, num_heads, use_bias, @@ -140,7 +142,7 @@ def forward(self, batch): class GritTransformerLayer(nn.Module): """ - Proposed Transformer Layer for GRIT + Transformer Layer for GRIT """ def __init__(self, in_dim, out_dim, num_heads, dropout=0.0, diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 7caed0f..a1717d1 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -12,7 +12,8 @@ class BatchNorm1dNode(torch.nn.Module): Args: dim_in (int): BatchNorm input dimension. - TODO fill in comments + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. """ def __init__(self, dim_in, eps, momentum): super().__init__() @@ -88,7 +89,7 @@ def forward(self, batch): class GraphHead(nn.Module): """ - Prediction head for graph prediction tasks. + Prediction head for decoding tasks. Args: dim_in (int): Input dimension. dim_out (int): Output dimension. For binary prediction, dim_out=1. @@ -97,34 +98,18 @@ class GraphHead(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - # self.deg_scaler = False - # self.fwl = False - - # list_FC_layers = [ - # nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) - # for l in range(L)] - # list_FC_layers.append( - # nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) + self.FC_layers = nn.Sequential( nn.Linear(dim_in, dim_in), nn.LeakyReLU(), nn.Linear(dim_in, dim_out), - ) #nn.ModuleList(list_FC_layers) - # self.L = L - # self.activation = register.act_dict[cfg.gnn.act]() - # note: modified to add () in the end from original code of 'GPS' - # potentially due to the change of PyG/GraphGym version + ) def _apply_index(self, batch): return batch.graph_feature, batch.y def forward(self, batch): - # graph_emb = self.pooling_fun(batch.x, batch.batch) graph_emb = self.FC_layers(batch.x) - # for l in range(self.L): - # graph_emb = self.FC_layers[l](graph_emb) - # graph_emb = self.activation(graph_emb) - # graph_emb = self.FC_layers[self.L](graph_emb) batch.graph_feature = graph_emb pred, label = self._apply_index(batch) return pred @@ -132,9 +117,12 @@ def forward(self, batch): @MODELS_REGISTRY.register("GRIT") class GritTransformer(torch.nn.Module): - ''' - The proposed GritTransformer (Graph Inductive Bias Transformer) - ''' + """ + The GritTransformer (Graph Inductive Bias Transformer) from + Graph Inductive Biases in Transformers without Message Passing, L. Ma et al., + 2023. + + """ def __init__(self, args): super().__init__() @@ -204,20 +192,19 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - # self.decoder = nn.Sequential( - # nn.Linear(dim_inner, dim_inner), - # nn.LeakyReLU(), - # nn.Linear(dim_inner, dim_out), - # ) - self.decoder = GraphHead(dim_inner, dim_out) - def forward(self, batch): # self, x, pe, edge_index, edge_attr, batch # gps parameters - #print('process--->>', batch) # TODO remove print + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. 
+        """
         for module in self.children():
             batch = module(batch)
         return batch
diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py
index 2dadd35..1f7fd10 100644
--- a/gridfm_graphkit/models/rrwp_encoder.py
+++ b/gridfm_graphkit/models/rrwp_encoder.py
@@ -51,10 +51,10 @@ def full_edge_index(edge_index, batch=None):
 class RRWPLinearNodeEncoder(torch.nn.Module):
     """
-        FC_1(RRWP) + FC_2 (Node-attr)
-        note: FC_2 is given by the Typedict encoder of node-attr in some cases
-        Parameters:
-        num_classes - the number of classes for the embedding mapping to learn
+    FC_1(RRWP) + FC_2 (Node-attr)
+    note: FC_2 is given by the Typedict encoder of node-attr in some cases
+    Parameters:
+    num_classes - the number of classes for the embedding mapping to learn
     """
     def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"):
         super().__init__()
@@ -90,14 +90,14 @@ def forward(self, batch):
 class RRWPLinearEdgeEncoder(torch.nn.Module):
-    '''
-        Merge RRWP with given edge-attr and Zero-padding to all pairs of node
-        FC_1(RRWP) + FC_2(edge-attr)
-        - FC_2 given by the TypedictEncoder in same cases
-        - Zero-padding for non-existing edges in fully-connected graph
-        - (optional) add node-attr as the E_{i,i}'s attr
-        note: assuming node-attr and edge-attr is with the same dimension after Encoders
-    '''
+    """
+    Merge RRWP with given edge-attr and zero-padding to all pairs of nodes
+    FC_1(RRWP) + FC_2(edge-attr)
+    - FC_2 given by the TypedictEncoder in some cases
+    - Zero-padding for non-existing edges in fully-connected graph
+    - (optional) add node-attr as the E_{i,i}'s attr
+    note: assumes node-attr and edge-attr have the same dimension after Encoders
+    """
     def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False,
                  pad_to_full_graph=True,
                  fill_value=0.,
                  add_node_attr_as_self_loop=False,
@@ -143,10 +143,7 @@ def forward(self, batch):
         if self.overwrite_old_attr:
             out_idx, out_val = rrwp_idx, rrwp_val
         else:
-            # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.)
             edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.)
-            # print('xxxx', edge_attr.size(), rrwp_val.size())
-            # print('yyyy', edge_index.size(), rrwp_idx.size())
             out_idx, out_val = torch_sparse.coalesce(
                 torch.cat([edge_index, rrwp_idx], dim=1),
                 torch.cat([edge_attr, rrwp_val], dim=0),
diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py
index da2f478..0d1743b 100644
--- a/gridfm_graphkit/tasks/feature_reconstruction_task.py
+++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py
@@ -111,13 +111,6 @@ def on_fit_start(self):
     def shared_step(self, batch):
         output = self.forward(
-            # TODO update args list in the GPS Transf.
for consistency - # x=batch.x, - # pe=batch.pe, - # edge_index=batch.edge_index, - # edge_attr=batch.edge_attr, - # batch=batch.batch, - # mask=batch.mask, batch ) From 53d564415e4a8c99f10c7882c343bf38fdadf766 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:07:44 -0500 Subject: [PATCH 26/62] deep copy in test method Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/feature_reconstruction_task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index 0d1743b..e6a4749 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -5,6 +5,7 @@ import numpy as np import os import pandas as pd +import copy from lightning.pytorch.loggers import MLFlowLogger from gridfm_graphkit.io.param_handler import load_model, get_loss_function @@ -162,7 +163,7 @@ def validation_step(self, batch, batch_idx): return loss_dict["loss"] def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(batch) + output, loss_dict = self.shared_step(copy.deepcopy(batch)) dataset_name = self.args.data.networks[dataloader_idx] From e23c9c62fe6698a397e8bc160d41e9fcbe332eef Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:13:19 -0500 Subject: [PATCH 27/62] basic RWSE flown over Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 18 +++-- gridfm_graphkit/datasets/posenc_stats.py | 67 +++++++++++++++++++ .../datasets/powergrid_datamodule.py | 8 +++ gridfm_graphkit/models/grit_transformer.py | 2 + .../tasks/feature_reconstruction_task.py | 2 + 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 8f11c93..e0cdf3e 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -18,20 +18,24 @@ data: normalization: baseMVAnorm scenarios: # - 5000 - - 50000 - - 50000 - - 30000 - - 50000 - - 50000 + - 5000 + - 5000 + - 3000 + - 5000 + - 5000 test_ratio: 0.1 val_ratio: 0.1 workers: 0 - posenc_RRWP: # TODO maybe better with data section... - enable: True + posenc_RRWP: + enable: False ksteps: 21 add_identity: True add_node_attr: False add_inverse: False + posenc_RWSE: # TODO verify functionality + enable: True + kernel: + times: 21 model: attention_head: 8 dropout: 0.1 diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 21e7841..0588d87 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -15,6 +15,8 @@ from torch_geometric.data import Data from typing import Any +from torch_geometric.utils.num_nodes import maybe_num_nodes + def compute_posenc_stats(data, pe_types, cfg): """Precompute positional encodings for the given graph. Supported PE statistics to precompute in original implementation, @@ -51,9 +53,74 @@ def compute_posenc_stats(data, pe_types, cfg): ) data = transform(data) + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. 
ogbg-ppa
+        else:
+            N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
+        if len(kernel_param.times) == 0:
+            raise ValueError("List of kernel times required for RWSE")
+        rw_landing = get_rw_landing_probs(ksteps=[xx + 1 for xx in range(kernel_param.times)],
+                                          edge_index=data.edge_index,
+                                          num_nodes=N
+                                          )
+        data.pestat_RWSE = rw_landing
+
     return data
+
+
+def get_rw_landing_probs(ksteps, edge_index, edge_weight=None,
+                         num_nodes=None, space_dim=0):
+    """Compute Random Walk landing probabilities for given list of K steps.
+
+    Args:
+        ksteps: List of k-steps for which to compute the RW landings
+        edge_index: PyG sparse representation of the graph
+        edge_weight: (optional) Edge weights
+        num_nodes: (optional) Number of nodes in the graph
+        space_dim: (optional) Estimated dimensionality of the space. Used to
+            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
+            In euclidean space, this correction means that the height of
+            the gaussian distribution stays almost constant across the number of
+            steps, if `space_dim` is the dimension of the euclidean space.
+
+    Returns:
+        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs
+    """
+    # Local imports for the names used below; editor's fix, on the assumption
+    # that the hunks visible in this patch do not add them at module level.
+    from torch_scatter import scatter_add
+    from torch_geometric.utils import to_dense_adj
+
+    if edge_weight is None:
+        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+    source, dest = edge_index[0], edge_index[1]
+    deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes)  # Out degrees.
+    deg_inv = deg.pow(-1.)
+    deg_inv.masked_fill_(deg_inv == float('inf'), 0)
+
+    if edge_index.numel() == 0:
+        P = edge_index.new_zeros((1, num_nodes, num_nodes))
+    else:
+        # P = D^-1 * A
+        P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes)  # 1 x (Num nodes) x (Num nodes)
+    rws = []
+    if ksteps == list(range(min(ksteps), max(ksteps) + 1)):
+        # Efficient way if ksteps are a consecutive sequence (most of the time the case)
+        Pk = P.clone().detach().matrix_power(min(ksteps))
+        for k in range(min(ksteps), max(ksteps) + 1):
+            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \
+                       (k ** (space_dim / 2)))
+            Pk = Pk @ P
+    else:
+        # Explicitly raising P to power k for each k \in ksteps.
+ for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + class ComputePosencStat(BaseTransform): def __init__(self, pe_types, cfg): self.pe_types = pe_types diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index 4b0320f..740ff51 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -143,6 +143,14 @@ def setup(self, stage: str): dataset.transform = pe_transform else: dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=['RWSE'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index a1717d1..e264d1b 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -204,6 +204,8 @@ def forward(self, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py index e6a4749..aa3370c 100644 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ b/gridfm_graphkit/tasks/feature_reconstruction_task.py @@ -77,6 +77,8 @@ def __init__(self, args, node_normalizers, edge_normalizers): def forward(self, batch): if batch.mask is not None: + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] return self.model(batch) From bfe2af0f2a447af11fc244d48c49ea2c2ebf4aa0 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:24:31 -0500 Subject: [PATCH 28/62] tested addition of RWSE Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 2 +- gridfm_graphkit/models/grit_transformer.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index 0588d87..e12d9f1 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -60,7 +60,7 @@ def compute_posenc_stats(data, pe_types, cfg): N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa else: N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
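# --- Editor's aside (not part of the patch): a small usage sketch of
# `get_rw_landing_probs` as added above. On a 3-node cycle, the returned
# tensor has shape (num_nodes, len(ksteps)); this is exactly what the RWSE
# branch of `compute_posenc_stats` stores in `data.pestat_RWSE`. The toy
# graph below is illustrative.
import torch

toy_edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                               [1, 0, 2, 1, 0, 2]])
toy_rw = get_rw_landing_probs(ksteps=[1, 2, 3],
                              edge_index=toy_edge_index,
                              num_nodes=3)
assert toy_rw.shape == (3, 3)  # one landing probability per node per step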
- if len(kernel_param.times) == 0: + if kernel_param.times == 0: raise ValueError("List of kernel times required for RWSE") rw_landing = get_rw_landing_probs( ksteps=[xx + 1 for xx in range(kernel_param.times)], diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index e264d1b..cab5ef9 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -206,6 +206,7 @@ def forward(self, batch): """ # print('xxxx',batch.x.min(), batch.x.max()) # print('yyyyy',batch.y.min(), batch.y.max()) + print('>>>>', batch) for module in self.children(): batch = module(batch) From c1e572181a704a86dab158a3f81a6696de80b78f Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:26:42 -0500 Subject: [PATCH 29/62] flow over kernel encoders Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 7 +- gridfm_graphkit/models/kernel_pos_encoder.py | 123 +++++++++++++++++++ 3 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 gridfm_graphkit/models/kernel_pos_encoder.py diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index e0cdf3e..f8654b9 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -50,7 +50,7 @@ model: encoder: node_encoder: True edge_encoder: True - node_encoder_name: TODO + node_encoder_name: RWSE node_encoder_bn: True edge_encoder_bn: True gt: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index cab5ef9..19351d9 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -69,7 +69,10 @@ def __init__( self.dim_in = dim_in if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if 'RWSE' in self.node_encoder_name: + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner) + else: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) # Update dim_in to reflect the new dimension fo the node features @@ -206,7 +209,7 @@ def forward(self, batch): """ # print('xxxx',batch.x.min(), batch.x.max()) # print('yyyyy',batch.y.min(), batch.y.max()) - print('>>>>', batch) + # print('>>>>', batch) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py new file mode 100644 index 0000000..36c4c51 --- /dev/null +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.register import register_node_encoder + + +class KernelPENodeEncoder(torch.nn.Module): + """Configurable kernel-based Positional Encoding node encoder. + + The choice of which kernel-based statistics to use is configurable through + setting of `kernel_type`. Based on this, the appropriate config is selected, + and also the appropriate variable with precomputed kernel stats is then + selected from PyG Data graphs in `forward` function. + E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'. + + PE of size `dim_pe` will get appended to each node feature vector. 
+    If `expand_x` is set True, original node features will be first linearly
+    projected to (dim_emb - dim_pe) size and then concatenated with PE.
+
+    Args:
+        dim_emb: Size of final node embedding
+        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
+    """
+
+    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE
+
+    def __init__(self, dim_emb, expand_x=True):
+        super().__init__()
+        if self.kernel_type is None:
+            raise ValueError(f"{self.__class__.__name__} has to be "
+                             f"preconfigured by setting 'kernel_type' class"
+                             f"variable before calling the constructor.")
+
+        dim_in = cfg.share.dim_in  # Expected original input node features dim
+
+        pecfg = getattr(cfg, f"posenc_{self.kernel_type}")
+        dim_pe = pecfg.dim_pe  # Size of the kernel-based PE embedding
+        num_rw_steps = len(pecfg.kernel.times)
+        model_type = pecfg.model.lower()  # Encoder NN model type for PEs
+        n_layers = pecfg.layers  # Num. layers in PE encoder model
+        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization layer type
+        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable
+
+        if dim_emb - dim_pe < 1:
+            raise ValueError(f"PE dim size {dim_pe} is too large for "
+                             f"desired embedding size of {dim_emb}.")
+
+        if expand_x:
+            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
+        self.expand_x = expand_x
+
+        if norm_type == 'batchnorm':
+            self.raw_norm = nn.BatchNorm1d(num_rw_steps)
+        else:
+            self.raw_norm = None
+
+        activation = nn.ReLU()  # register.act_dict[cfg.gnn.act]
+        if model_type == 'mlp':
+            layers = []
+            if n_layers == 1:
+                layers.append(nn.Linear(num_rw_steps, dim_pe))
+                layers.append(activation)
+            else:
+                layers.append(nn.Linear(num_rw_steps, 2 * dim_pe))
+                layers.append(activation)
+                for _ in range(n_layers - 2):
+                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
+                    layers.append(activation)
+                layers.append(nn.Linear(2 * dim_pe, dim_pe))
+                layers.append(activation)
+            self.pe_encoder = nn.Sequential(*layers)
+        elif model_type == 'linear':
+            self.pe_encoder = nn.Linear(num_rw_steps, dim_pe)
+        else:
+            raise ValueError(f"{self.__class__.__name__}: Does not support "
+                             f"'{model_type}' encoder model.")
+
+    def forward(self, batch):
+        pestat_var = f"pestat_{self.kernel_type}"
+        if not hasattr(batch, pestat_var):
+            raise ValueError(f"Precomputed '{pestat_var}' variable is "
+                             f"required for {self.__class__.__name__}; set "
+                             f"config 'posenc_{self.kernel_type}.enable' to "
+                             f"True, and also set 'posenc.kernel.times' values")
+
+        pos_enc = getattr(batch, pestat_var)  # (Num nodes) x (Num kernel times)
+        # pos_enc = batch.rw_landing  # (Num nodes) x (Num kernel times)
+        if self.raw_norm:
+            pos_enc = self.raw_norm(pos_enc)
+        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
+
+        # Expand node features if needed
+        if self.expand_x:
+            h = self.linear_x(batch.x)
+        else:
+            h = batch.x
+        # Concatenate final PEs to input embedding
+        batch.x = torch.cat((h, pos_enc), 1)
+        # Keep PE also separate in a variable (e.g. for skip connections to input)
+        if self.pass_as_var:
+            setattr(batch, f'pe_{self.kernel_type}', pos_enc)
+        return batch
+
+
+@register_node_encoder('RWSE')
+class RWSENodeEncoder(KernelPENodeEncoder):
+    """Random Walk Structural Encoding node encoder.
+    """
+    kernel_type = 'RWSE'
+
+
+@register_node_encoder('HKdiagSE')
+class HKdiagSENodeEncoder(KernelPENodeEncoder):
+    """Heat kernel (diagonal) Structural Encoding node encoder.
+ """ + kernel_type = 'HKdiagSE' + + +@register_node_encoder('ElstaticSE') +class ElstaticSENodeEncoder(KernelPENodeEncoder): + """Electrostatic interactions Structural Encoding node encoder. + """ + kernel_type = 'ElstaticSE' From 5e096832be119ffb9092af7d0f6dcd22fcf8f54e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:39:47 -0500 Subject: [PATCH 30/62] basic match of parameters for new encoder Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 4 +- gridfm_graphkit/models/grit_transformer.py | 4 +- gridfm_graphkit/models/kernel_pos_encoder.py | 55 +++----------------- 3 files changed, 13 insertions(+), 50 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index f8654b9..24349db 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -32,10 +32,12 @@ data: add_identity: True add_node_attr: False add_inverse: False - posenc_RWSE: # TODO verify functionality + posenc_RWSE: enable: True kernel: times: 21 + pe_dim: 20 # TODO unify with model.pe_dim + raw_norm_type: batchnorm model: attention_head: 8 dropout: 0.1 diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 19351d9..5a01c39 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -4,7 +4,7 @@ from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer - +from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder class BatchNorm1dNode(torch.nn.Module): @@ -70,7 +70,7 @@ def __init__( if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings if 'RWSE' in self.node_encoder_name: - self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner) + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.posenc_RWSE) else: self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py index 36c4c51..4b6a654 100644 --- a/gridfm_graphkit/models/kernel_pos_encoder.py +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.register import register_node_encoder class KernelPENodeEncoder(torch.nn.Module): @@ -24,22 +22,17 @@ class KernelPENodeEncoder(torch.nn.Module): kernel_type = None # Instantiated type of the KernelPE, e.g. RWSE - def __init__(self, dim_emb, expand_x=True): + def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): super().__init__() if self.kernel_type is None: raise ValueError(f"{self.__class__.__name__} has to be " f"preconfigured by setting 'kernel_type' class" f"variable before calling the constructor.") - dim_in = cfg.share.dim_in # Expected original input node features dim - - pecfg = getattr(cfg, f"posenc_{self.kernel_type}") dim_pe = pecfg.dim_pe # Size of the kernel-based PE embedding num_rw_steps = len(pecfg.kernel.times) - model_type = pecfg.model.lower() # Encoder NN model type for PEs - n_layers = pecfg.layers # Num. 
layers in PE encoder model norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type - self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable + # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable if dim_emb - dim_pe < 1: raise ValueError(f"PE dim size {dim_pe} is too large for " @@ -54,26 +47,8 @@ def __init__(self, dim_emb, expand_x=True): else: self.raw_norm = None - activation = nn.ReLU() # register.act_dict[cfg.gnn.act] - if model_type == 'mlp': - layers = [] - if n_layers == 1: - layers.append(nn.Linear(num_rw_steps, dim_pe)) - layers.append(activation) - else: - layers.append(nn.Linear(num_rw_steps, 2 * dim_pe)) - layers.append(activation) - for _ in range(n_layers - 2): - layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe)) - layers.append(activation) - layers.append(nn.Linear(2 * dim_pe, dim_pe)) - layers.append(activation) - self.pe_encoder = nn.Sequential(*layers) - elif model_type == 'linear': - self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) - else: - raise ValueError(f"{self.__class__.__name__}: Does not support " - f"'{model_type}' encoder model.") + self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) + def forward(self, batch): pestat_var = f"pestat_{self.kernel_type}" @@ -97,27 +72,13 @@ def forward(self, batch): # Concatenate final PEs to input embedding batch.x = torch.cat((h, pos_enc), 1) # Keep PE also separate in a variable (e.g. for skip connections to input) - if self.pass_as_var: - setattr(batch, f'pe_{self.kernel_type}', pos_enc) + # if self.pass_as_var: + # setattr(batch, f'pe_{self.kernel_type}', pos_enc) + return batch -@register_node_encoder('RWSE') class RWSENodeEncoder(KernelPENodeEncoder): """Random Walk Structural Encoding node encoder. """ - kernel_type = 'RWSE' - - -@register_node_encoder('HKdiagSE') -class HKdiagSENodeEncoder(KernelPENodeEncoder): - """Heat kernel (diagonal) Structural Encoding node encoder. - """ - kernel_type = 'HKdiagSE' - - -@register_node_encoder('ElstaticSE') -class ElstaticSENodeEncoder(KernelPENodeEncoder): - """Electrostatic interactions Structural Encoding node encoder. 
- """ - kernel_type = 'ElstaticSE' + kernel_type = 'RWSE' \ No newline at end of file From 1378770b6d31ceb7644f0cea64801c0dde25f90c Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 24 Nov 2025 14:48:54 -0500 Subject: [PATCH 31/62] tested functionality of new encoding Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 15 +++++++++------ gridfm_graphkit/models/grit_transformer.py | 4 ++-- gridfm_graphkit/models/kernel_pos_encoder.py | 4 ++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml index 24349db..bf56dff 100644 --- a/examples/config/grit_pretraining.yaml +++ b/examples/config/grit_pretraining.yaml @@ -25,7 +25,7 @@ data: - 5000 test_ratio: 0.1 val_ratio: 0.1 - workers: 0 + workers: 4 posenc_RRWP: enable: False ksteps: 21 @@ -33,11 +33,9 @@ data: add_node_attr: False add_inverse: False posenc_RWSE: - enable: True - kernel: - times: 21 - pe_dim: 20 # TODO unify with model.pe_dim - raw_norm_type: batchnorm + enable: True + kernel: + times: 21 # TODO unify with model model: attention_head: 8 dropout: 0.1 @@ -55,6 +53,11 @@ model: node_encoder_name: RWSE node_encoder_bn: True edge_encoder_bn: True + posenc_RWSE: + kernel: + times: 21 + pe_dim: 20 # TODO unify with model.pe_dim + raw_norm_type: batchnorm gt: layer_type: GritTransformer dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 5a01c39..ab0d51b 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -69,8 +69,8 @@ def __init__( self.dim_in = dim_in if args.encoder.node_encoder: # Encode integer node features via nn.Embeddings - if 'RWSE' in self.node_encoder_name: - self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.posenc_RWSE) + if 'RWSE' in args.encoder.node_encoder_name: + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.encoder.posenc_RWSE) else: self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) if args.encoder.node_encoder_bn: diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py index 4b6a654..b24078d 100644 --- a/gridfm_graphkit/models/kernel_pos_encoder.py +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -29,8 +29,8 @@ def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): f"preconfigured by setting 'kernel_type' class" f"variable before calling the constructor.") - dim_pe = pecfg.dim_pe # Size of the kernel-based PE embedding - num_rw_steps = len(pecfg.kernel.times) + dim_pe = pecfg.pe_dim # Size of the kernel-based PE embedding + num_rw_steps = pecfg.kernel.times norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable From 2eb3a10b8ea8731f58af2b4d496de03e733dbf2b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:26:09 -0400 Subject: [PATCH 32/62] settle final merge conflicts Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- .../datasets/hetero_powergrid_datamodule.py | 8 + .../datasets/powergrid_datamodule.py | 230 ----------- .../tasks/feature_reconstruction_task.py | 356 ------------------ 3 files changed, 8 insertions(+), 586 deletions(-) delete 
mode 100644 gridfm_graphkit/datasets/powergrid_datamodule.py delete mode 100644 gridfm_graphkit/tasks/feature_reconstruction_task.py diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 4474c2e..e6a4cfd 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -163,6 +163,14 @@ def setup(self, stage: str): dataset.transform = pe_transform else: dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=['RWSE'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) self.datasets.append(dataset) diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py deleted file mode 100644 index 740ff51..0000000 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ /dev/null @@ -1,230 +0,0 @@ -import torch -from torch_geometric.loader import DataLoader -from torch.utils.data import ConcatDataset -from torch.utils.data import Subset -import torch.distributed as dist -from gridfm_graphkit.io.param_handler import ( - NestedNamespace, - load_normalizer, - get_transform, -) -from gridfm_graphkit.datasets.utils import split_dataset -from gridfm_graphkit.datasets.powergrid_dataset import GridDatasetDisk - -from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat - -import torch_geometric.transforms as T - -import numpy as np -import random -import warnings -import os -import lightning as L - - -class LitGridDataModule(L.LightningDataModule): - """ - PyTorch Lightning DataModule for power grid datasets. - - This datamodule handles loading, preprocessing, splitting, and batching - of power grid graph datasets (`GridDatasetDisk`) for training, validation, - testing, and prediction. It ensures reproducibility through fixed seeds. - - Args: - args (NestedNamespace): Experiment configuration. - data_dir (str, optional): Root directory for datasets. Defaults to "./data". - - Attributes: - batch_size (int): Batch size for all dataloaders. From ``args.training.batch_size`` - node_normalizers (list): List of node feature normalizers, one per dataset. - edge_normalizers (list): List of edge feature normalizers, one per dataset. - datasets (list): Original datasets for each network. - train_datasets (list): Train splits for each network. - val_datasets (list): Validation splits for each network. - test_datasets (list): Test splits for each network. - train_dataset_multi (ConcatDataset): Concatenated train datasets for multi-network training. - val_dataset_multi (ConcatDataset): Concatenated validation datasets for multi-network validation. - _is_setup_done (bool): Tracks whether `setup` has been executed to avoid repeated processing. - - Methods: - setup(stage): - Load and preprocess datasets, split into train/val/test, and store normalizers. - Handles distributed preprocessing safely. - train_dataloader(): - Returns a DataLoader for concatenated training datasets. - val_dataloader(): - Returns a DataLoader for concatenated validation datasets. - test_dataloader(): - Returns a list of DataLoaders, one per test dataset. - predict_dataloader(): - Returns a list of DataLoaders, one per test dataset for prediction. 
- - Notes: - - Preprocessing is only performed on rank 0 in distributed settings. - - Subsets and splits are deterministic based on the provided random seed. - - Normalizers are loaded for each network independently. - - Test and predict dataloaders are returned as lists, one per dataset. - - Example: - ```python - from gridfm_graphkit.datasets.powergrid_datamodule import LitGridDataModule - from gridfm_graphkit.io.param_handler import NestedNamespace - import yaml - - with open("config/config.yaml") as f: - base_config = yaml.safe_load(f) - args = NestedNamespace(**base_config) - - datamodule = LitGridDataModule(args, data_dir="./data") - - datamodule.setup("fit") - train_loader = datamodule.train_dataloader() - ``` - """ - - def __init__(self, args: NestedNamespace, data_dir: str = "./data"): - super().__init__() - self.data_dir = data_dir - self.batch_size = int(args.training.batch_size) - self.args = args - self.node_normalizers = [] - self.edge_normalizers = [] - self.datasets = [] - self.train_datasets = [] - self.val_datasets = [] - self.test_datasets = [] - self._is_setup_done = False - - def setup(self, stage: str): - if self._is_setup_done: - print(f"Setup already done for stage={stage}, skipping...") - return - - for i, network in enumerate(self.args.data.networks): - node_normalizer, edge_normalizer = load_normalizer(args=self.args) - self.node_normalizers.append(node_normalizer) - self.edge_normalizers.append(edge_normalizer) - - # Create torch dataset and split - data_path_network = os.path.join(self.data_dir, network) - - # Run preprocessing only on rank 0 - if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0: - print(f"Pre-processing of {network} dataset on rank 0") - _ = GridDatasetDisk( # just to trigger processing - root=data_path_network, - norm_method=self.args.data.normalization, - node_normalizer=node_normalizer, - edge_normalizer=edge_normalizer, - pe_dim=self.args.model.pe_dim, - mask_dim=self.args.data.mask_dim, - transform=get_transform(args=self.args), - ) - - # All ranks wait here until processing is done - if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.barrier() - - dataset = GridDatasetDisk( - root=data_path_network, - norm_method=self.args.data.normalization, - node_normalizer=node_normalizer, - edge_normalizer=edge_normalizer, - pe_dim=self.args.model.pe_dim, - mask_dim=self.args.data.mask_dim, - transform=get_transform(args=self.args), - ) - - if ('posenc_RRWP' in self.args.data) and self.args.data.posenc_RRWP.enable: - pe_transform = ComputePosencStat(pe_types=['RRWP'], - cfg=self.args.data - ) - if dataset.transform is None: - dataset.transform = pe_transform - else: - dataset.transform = T.Compose([pe_transform, dataset.transform]) - if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: - pe_transform = ComputePosencStat(pe_types=['RWSE'], - cfg=self.args.data - ) - if dataset.transform is None: - dataset.transform = pe_transform - else: - dataset.transform = T.Compose([pe_transform, dataset.transform]) - - self.datasets.append(dataset) - - num_scenarios = self.args.data.scenarios[i] - if num_scenarios > len(dataset): - warnings.warn( - f"Requested number of scenarios ({num_scenarios}) exceeds dataset size ({len(dataset)}). 
" - "Using the full dataset instead.", - ) - num_scenarios = len(dataset) - - # Create a subset - all_indices = list(range(len(dataset))) - # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order - random.seed(self.args.seed) - random.shuffle(all_indices) - subset_indices = all_indices[:num_scenarios] - dataset = Subset(dataset, subset_indices) - - # Random seed set before every split, same as above - np.random.seed(self.args.seed) - train_dataset, val_dataset, test_dataset = split_dataset( - dataset, - self.data_dir, - self.args.data.val_ratio, - self.args.data.test_ratio, - ) - - self.train_datasets.append(train_dataset) - self.val_datasets.append(val_dataset) - self.test_datasets.append(test_dataset) - - self.train_dataset_multi = ConcatDataset(self.train_datasets) - self.val_dataset_multi = ConcatDataset(self.val_datasets) - self._is_setup_done = True - - def train_dataloader(self): - return DataLoader( - self.train_dataset_multi, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.args.data.workers, - pin_memory=True, - ) - - def val_dataloader(self): - return DataLoader( - self.val_dataset_multi, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - - def test_dataloader(self): - return [ - DataLoader( - i, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - for i in self.test_datasets - ] - - def predict_dataloader(self): - return [ - DataLoader( - i, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, - ) - for i in self.test_datasets - ] diff --git a/gridfm_graphkit/tasks/feature_reconstruction_task.py b/gridfm_graphkit/tasks/feature_reconstruction_task.py deleted file mode 100644 index aa3370c..0000000 --- a/gridfm_graphkit/tasks/feature_reconstruction_task.py +++ /dev/null @@ -1,356 +0,0 @@ -import torch -from torch.optim.lr_scheduler import ReduceLROnPlateau -import lightning as L -from pytorch_lightning.utilities import rank_zero_only -import numpy as np -import os -import pandas as pd -import copy - -from lightning.pytorch.loggers import MLFlowLogger -from gridfm_graphkit.io.param_handler import load_model, get_loss_function -import torch.nn.functional as F -from gridfm_graphkit.datasets.globals import PQ, PV, REF, PD, QD, PG, QG, VM, VA - - -class FeatureReconstructionTask(L.LightningModule): - """ - PyTorch Lightning task for node feature reconstruction on power grid graphs. - - This task wraps a GridFM model inside a LightningModule and defines the full - training, validation, testing, and prediction logic. It is designed to - reconstruct masked node features from graph-structured input data, using - datasets and normalizers provided by `gridfm-graphkit`. - - Args: - args (NestedNamespace): Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. - node_normalizers (list): One normalizer per dataset to (de)normalize node features. - edge_normalizers (list): One normalizer per dataset to (de)normalize edge features. - - Attributes: - model (torch.nn.Module): model loaded via `load_model`. - loss_fn (callable): Loss function resolved from configuration. - batch_size (int): Training batch size. From ``args.training.batch_size`` - node_normalizers (list): Dataset-wise node feature normalizers. - edge_normalizers (list): Dataset-wise edge feature normalizers. 
- - Methods: - forward(x, pe, edge_index, edge_attr, batch, mask=None): - Forward pass with optional feature masking. - training_step(batch): - One training step: computes loss, logs metrics, returns loss. - validation_step(batch, batch_idx): - One validation step: computes losses and logs metrics. - test_step(batch, batch_idx, dataloader_idx=0): - Evaluate on test data, compute per-node-type MSEs, and log per-dataset metrics. - predict_step(batch, batch_idx, dataloader_idx=0): - Run inference and return denormalized outputs + node masks. - configure_optimizers(): - Setup Adam optimizer and ReduceLROnPlateau scheduler. - on_fit_start(): - Save normalization statistics at the beginning of training. - on_test_end(): - Collect test metrics across datasets and export summary CSV reports. - - Notes: - - Node types are distinguished using the global constants (`PQ`, `PV`, `REF`). - - The datamodule must provide `batch.mask` for masking node features. - - Test metrics include per-node-type RMSE for [Pd, Qd, Pg, Qg, Vm, Va]. - - Reports are saved under `/test/.csv`. - - Example: - ```python - model = FeatureReconstructionTask(args, node_normalizers, edge_normalizers) - output = model(batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch) - ``` - """ - - def __init__(self, args, node_normalizers, edge_normalizers): - super().__init__() - self.model = load_model(args=args) - self.args = args - self.loss_fn = get_loss_function(args) - self.batch_size = int(args.training.batch_size) - self.node_normalizers = node_normalizers - self.edge_normalizers = edge_normalizers - self.save_hyperparameters() - - def forward(self, batch): - if batch.mask is not None: - # print('xxxx',batch.x.min(), batch.x.max()) - # print('yyyyy',batch.y.min(), batch.y.max()) - mask_value_expanded = self.model.mask_value.expand(batch.x.shape[0], -1) - batch.x[:, : batch.mask.shape[1]][batch.mask] = mask_value_expanded[batch.mask] - return self.model(batch) - - @rank_zero_only - def on_fit_start(self): - # Determine save path - if isinstance(self.logger, MLFlowLogger): - log_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - "stats", - ) - else: - log_dir = os.path.join(self.logger.save_dir, "stats") - - os.makedirs(log_dir, exist_ok=True) - log_stats_path = os.path.join(log_dir, "normalization_stats.txt") - - # Collect normalization stats - with open(log_stats_path, "w") as log_file: - for i, normalizer in enumerate(self.node_normalizers): - log_file.write( - f"Node Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", - ) - - for i, normalizer in enumerate(self.edge_normalizers): - log_file.write( - f"Edge Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", - ) - - def shared_step(self, batch): - output = self.forward( - batch - ) - - loss_dict = self.loss_fn( - output, - batch.y, - batch.edge_index, - batch.edge_attr, - batch.mask, - ) - return output, loss_dict - - def training_step(self, batch): - _, loss_dict = self.shared_step(batch) - current_lr = self.optimizer.param_groups[0]["lr"] - metrics = {} - metrics["Training Loss"] = loss_dict["loss"].detach() - metrics["Learning Rate"] = current_lr - for metric, value in metrics.items(): - self.log( - metric, - value, - batch_size=batch.num_graphs, - sync_dist=True, - on_epoch=True, - prog_bar=True, - logger=True, - on_step=False, - ) - - return loss_dict["loss"] - - def validation_step(self, batch, batch_idx): - _, loss_dict = self.shared_step(batch) - 
loss_dict["loss"] = loss_dict["loss"].detach() - for metric, value in loss_dict.items(): - metric_name = f"Validation {metric}" - self.log( - metric_name, - value, - batch_size=batch.num_graphs, - sync_dist=True, - on_epoch=True, - prog_bar=True, - logger=True, - on_step=False, - ) - - return loss_dict["loss"] - - def test_step(self, batch, batch_idx, dataloader_idx=0): - output, loss_dict = self.shared_step(copy.deepcopy(batch)) - - dataset_name = self.args.data.networks[dataloader_idx] - - output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) - target_denorm = self.node_normalizers[dataloader_idx].inverse_transform(batch.y) - - mask_PQ = batch.x[:, PQ] == 1 - mask_PV = batch.x[:, PV] == 1 - mask_REF = batch.x[:, REF] == 1 - - mse_PQ = F.mse_loss( - output_denorm[mask_PQ], - target_denorm[mask_PQ], - reduction="none", - ) - mse_PV = F.mse_loss( - output_denorm[mask_PV], - target_denorm[mask_PV], - reduction="none", - ) - mse_REF = F.mse_loss( - output_denorm[mask_REF], - target_denorm[mask_REF], - reduction="none", - ) - - mse_PQ = mse_PQ.mean(dim=0) - mse_PV = mse_PV.mean(dim=0) - mse_REF = mse_REF.mean(dim=0) - - loss_dict["MSE PQ nodes - PD"] = mse_PQ[PD] - loss_dict["MSE PV nodes - PD"] = mse_PV[PD] - loss_dict["MSE REF nodes - PD"] = mse_REF[PD] - - loss_dict["MSE PQ nodes - QD"] = mse_PQ[QD] - loss_dict["MSE PV nodes - QD"] = mse_PV[QD] - loss_dict["MSE REF nodes - QD"] = mse_REF[QD] - - loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG] - loss_dict["MSE PV nodes - PG"] = mse_PV[PG] - loss_dict["MSE REF nodes - PG"] = mse_REF[PG] - - loss_dict["MSE PQ nodes - QG"] = mse_PQ[QG] - loss_dict["MSE PV nodes - QG"] = mse_PV[QG] - loss_dict["MSE REF nodes - QG"] = mse_REF[QG] - - loss_dict["MSE PQ nodes - VM"] = mse_PQ[VM] - loss_dict["MSE PV nodes - VM"] = mse_PV[VM] - loss_dict["MSE REF nodes - VM"] = mse_REF[VM] - - loss_dict["MSE PQ nodes - VA"] = mse_PQ[VA] - loss_dict["MSE PV nodes - VA"] = mse_PV[VA] - loss_dict["MSE REF nodes - VA"] = mse_REF[VA] - - loss_dict["Test loss"] = loss_dict.pop("loss").detach() - for metric, value in loss_dict.items(): - metric_name = f"{dataset_name}/{metric}" - if "p.u." in metric: - # Denormalize metrics expressed in p.u. 
- value *= self.node_normalizers[dataloader_idx].baseMVA - metric_name = metric_name.replace("in p.u.", "").strip() - self.log( - metric_name, - value, - batch_size=batch.num_graphs, - add_dataloader_idx=False, - sync_dist=True, - logger=False, - ) - return - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - output, _ = self.shared_step(batch) - output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) - - # Count buses and generate per-node scenario_id - bus_counts = batch.batch.unique(return_counts=True)[1] - scenario_ids = batch.scenario_id # shape: [num_graphs] - scenario_per_node = torch.cat( - [ - torch.full((count,), sid, dtype=torch.int32) - for count, sid in zip(bus_counts, scenario_ids) - ], - ) - - bus_numbers = np.concatenate([np.arange(count.item()) for count in bus_counts]) - - return { - "output": output_denorm.cpu().numpy(), - "scenario_id": scenario_per_node, - "bus_number": bus_numbers, - } - - @rank_zero_only - def on_test_end(self): - if isinstance(self.logger, MLFlowLogger): - artifact_dir = os.path.join( - self.logger.save_dir, - self.logger.experiment_id, - self.logger.run_id, - "artifacts", - ) - else: - artifact_dir = self.logger.save_dir - - final_metrics = self.trainer.callback_metrics - grouped_metrics = {} - - for full_key, value in final_metrics.items(): - try: - value = value.item() - except AttributeError: - pass - - if "/" in full_key: - dataset_name, metric = full_key.split("/", 1) - if dataset_name not in grouped_metrics: - grouped_metrics[dataset_name] = {} - grouped_metrics[dataset_name][metric] = value - - for dataset, metrics in grouped_metrics.items(): - rmse_PQ = [ - metrics.get(f"MSE PQ nodes - {label}", float("nan")) ** 0.5 - for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - rmse_PV = [ - metrics.get(f"MSE PV nodes - {label}", float("nan")) ** 0.5 - for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - rmse_REF = [ - metrics.get(f"MSE REF nodes - {label}", float("nan")) ** 0.5 - for label in ["PD", "QD", "PG", "QG", "VM", "VA"] - ] - - avg_active_res = metrics.get("Active Power Loss", " ") - avg_reactive_res = metrics.get("Reactive Power Loss", " ") - - data = { - "Metric": [ - "RMSE-PQ", - "RMSE-PV", - "RMSE-REF", - "Avg. active res. (MW)", - "Avg. reactive res. 
(MVar)", - ], - "Pd (MW)": [ - rmse_PQ[0], - rmse_PV[0], - rmse_REF[0], - avg_active_res, - avg_reactive_res, - ], - "Qd (MVar)": [rmse_PQ[1], rmse_PV[1], rmse_REF[1], " ", " "], - "Pg (MW)": [rmse_PQ[2], rmse_PV[2], rmse_REF[2], " ", " "], - "Qg (MVar)": [rmse_PQ[3], rmse_PV[3], rmse_REF[3], " ", " "], - "Vm (p.u.)": [rmse_PQ[4], rmse_PV[4], rmse_REF[4], " ", " "], - "Va (degree)": [rmse_PQ[5], rmse_PV[5], rmse_REF[5], " ", " "], - } - - df = pd.DataFrame(data) - - test_dir = os.path.join(artifact_dir, "test") - os.makedirs(test_dir, exist_ok=True) - csv_path = os.path.join(test_dir, f"{dataset}.csv") - df.to_csv(csv_path, index=False) - - def configure_optimizers(self): - self.optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.args.optimizer.learning_rate, - betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), - ) - - self.scheduler = ReduceLROnPlateau( - self.optimizer, - mode="min", - factor=self.args.optimizer.lr_decay, - patience=self.args.optimizer.lr_patience, - ) - config_optim = { - "optimizer": self.optimizer, - "lr_scheduler": { - "scheduler": self.scheduler, - "monitor": "Validation loss", - "reduce_on_plateau": True, - }, - } - return config_optim From 6d89bba750c488b7e5e61561f4f491ed95b60da3 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:03:39 -0400 Subject: [PATCH 33/62] connect grit and encoders with hetero-adapter Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/posenc_stats.py | 33 ++++++- .../models/gnn_heterogeneous_gns.py | 18 ++-- gridfm_graphkit/models/grit_transformer.py | 89 ++++++++++++++++++- gridfm_graphkit/tasks/reconstruction_tasks.py | 11 +-- 4 files changed, 134 insertions(+), 17 deletions(-) diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py index e12d9f1..5263b48 100644 --- a/gridfm_graphkit/datasets/posenc_stats.py +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -12,7 +12,7 @@ from gridfm_graphkit.datasets.rrwp import add_full_rrwp from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data +from torch_geometric.data import Data, HeteroData from typing import Any from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -129,9 +129,38 @@ def __init__(self, pe_types, cfg): def forward(self, data: Any) -> Any: pass - def __call__(self, data: Data) -> Data: + def __call__(self, data) -> Data: + if isinstance(data, HeteroData): + return self._call_hetero(data) + data = compute_posenc_stats(data, pe_types=self.pe_types, cfg=self.cfg ) return data + + def _call_hetero(self, data: HeteroData) -> HeteroData: + """Compute PE on the bus-only subgraph and store results on data['bus'].""" + bus_data = Data( + x=data["bus"].x, + edge_index=data["bus", "connects", "bus"].edge_index, + num_nodes=data["bus"].num_nodes, + ) + if hasattr(data["bus", "connects", "bus"], "edge_weight"): + bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight + + bus_data = compute_posenc_stats( + bus_data, pe_types=self.pe_types, cfg=self.cfg, + ) + + # Copy computed PE attributes back onto the HeteroData bus store + pe_attrs = [ + "pestat_RWSE", # RWSE + "rrwp", "rrwp_index", "rrwp_val", # RRWP + "log_deg", "deg", # degree info from RRWP + ] + for attr in pe_attrs: + if hasattr(bus_data, attr): + data["bus"][attr] = getattr(bus_data, attr) + + return data diff --git a/gridfm_graphkit/models/gnn_heterogeneous_gns.py 
b/gridfm_graphkit/models/gnn_heterogeneous_gns.py index 1073560..e274a1c 100644 --- a/gridfm_graphkit/models/gnn_heterogeneous_gns.py +++ b/gridfm_graphkit/models/gnn_heterogeneous_gns.py @@ -146,14 +146,20 @@ def __init__(self, args) -> None: # container for monitoring residual norms per layer and type self.layer_residuals = {} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + def forward(self, batch): """ - x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} - edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") - edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) - batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. - mask: optional mask per node (applies when computing residuals) + Accepts a PyG HeteroData batch and extracts the required tensors. + + batch: HeteroData/Batch containing: + x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) + mask_dict: dict mapping node/bus types to mask tensors """ + x_dict = batch.x_dict + edge_index_dict = batch.edge_index_dict + edge_attr_dict = batch.edge_attr_dict + mask_dict = batch.mask_dict self.layer_residuals = {} diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index ab0d51b..17b1539 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -1,6 +1,7 @@ from gridfm_graphkit.io.registries import MODELS_REGISTRY import torch from torch import nn +from torch_geometric.data import Data from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer @@ -118,7 +119,6 @@ def forward(self, batch): return pred -@MODELS_REGISTRY.register("GRIT") class GritTransformer(torch.nn.Module): """ The GritTransformer (Graph Inductive Bias Transformer) from @@ -214,3 +214,90 @@ def forward(self, batch): batch = module(batch) return batch + + +@MODELS_REGISTRY.register("GRIT") +class GritHeteroAdapter(torch.nn.Module): + """Adapter that enables the homogeneous GRIT transformer to operate on + heterogeneous power-grid graphs. + + Extracts the bus-only homogeneous subgraph using PyG's native HeteroData + accessors, runs it through the GRIT encoder and transformer layers, and + produces per-node-type predictions. Generator output comes from a + lightweight standalone MLP (generators are not seen by the transformer). + + Returns: + dict: ``{"bus": Tensor[num_bus, output_bus_dim], + "gen": Tensor[num_gen, output_gen_dim]}`` + """ + + def __init__(self, args): + super().__init__() + + dim_inner = args.model.hidden_size + output_bus_dim = args.model.output_bus_dim + output_gen_dim = args.model.output_gen_dim + input_gen_dim = args.model.input_gen_dim + + # Ensure config keys expected by GritTransformer are present. 
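# --- Editor's note: a minimal sketch (dummy shapes) of the HeteroData layout
# that the hetero-adapter and the new forward(batch) signature consume. The
# .x_dict / .edge_index_dict accessors are standard PyG HeteroData
# properties; the "bus"/"gen" store names follow this patch series.
import torch
from torch_geometric.data import HeteroData

toy = HeteroData()
toy["bus"].x = torch.randn(4, 15)                      # 4 buses, 15 features
toy["gen"].x = torch.randn(2, 6)                       # 2 generators, 6 features
toy["bus", "connects", "bus"].edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
toy["gen", "connected_to", "bus"].edge_index = torch.tensor([[0, 1], [0, 3]])

assert set(toy.x_dict) == {"bus", "gen"}
assert ("bus", "connects", "bus") in toy.edge_index_dict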
+ # input_dim = bus feature dimension (used by FeatureEncoder) + # output_dim = bus output dimension (used by the unused GraphHead) + if not hasattr(args.model, "input_dim"): + args.model.input_dim = args.model.input_bus_dim + if not hasattr(args.model, "output_dim"): + args.model.output_dim = output_bus_dim + + # The original homogeneous GRIT + # (encoder + optional PE encoders + transformer layers + GraphHead) + self.grit = GritTransformer(args) + + # Per-node-type output heads (replace GraphHead for hetero output) + self.bus_head = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_bus_dim), + ) + self.gen_head = nn.Sequential( + nn.Linear(input_gen_dim, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_gen_dim), + ) + + def forward(self, batch): + """Forward pass on a heterogeneous power-grid batch. + + Args: + batch: A batched ``HeteroData`` with node types ``"bus"`` and + ``"gen"``, and edge type ``("bus", "connects", "bus")``. + + Returns: + dict with keys ``"bus"`` and ``"gen"``, each mapping to the + predicted output features. + """ + # --- Extract bus-only homogeneous subgraph --- + homo = Data( + x=batch["bus"].x, + y=batch["bus"].y, + edge_index=batch["bus", "connects", "bus"].edge_index, + edge_attr=batch["bus", "connects", "bus"].edge_attr, + batch=batch["bus"].batch, + ) + + # Forward positional-encoding attributes if present + for attr in ("pestat_RWSE", "rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"): + if hasattr(batch["bus"], attr): + setattr(homo, attr, getattr(batch["bus"], attr)) + + # --- Run GRIT encoder + PE encoders + transformer layers --- + homo = self.grit.encoder(homo) + if hasattr(self.grit, "rrwp_abs_encoder"): + homo = self.grit.rrwp_abs_encoder(homo) + if hasattr(self.grit, "rrwp_rel_encoder"): + homo = self.grit.rrwp_rel_encoder(homo) + homo = self.grit.layers(homo) + + # --- Per-type decoding --- + bus_out = self.bus_head(homo.x) + gen_out = self.gen_head(batch["gen"].x) + + return {"bus": bus_out, "gen": gen_out} diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 8742646..43e243c 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -39,16 +39,11 @@ def __init__(self, args, data_normalizers): self.batch_size = int(args.training.batch_size) self.test_outputs = {i: [] for i in range(len(args.data.networks))} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): - return self.model(x_dict, edge_index_dict, edge_attr_dict, mask_dict) + def forward(self, batch): + return self.model(batch) def shared_step(self, batch): - output = self.forward( - x_dict=batch.x_dict, - edge_index_dict=batch.edge_index_dict, - edge_attr_dict=batch.edge_attr_dict, - mask_dict=batch.mask_dict, - ) + output = self.forward(batch) loss_dict = self.loss_fn( output, From 926dff5f89e5de06ef6c0bd180dfe75ad03f119b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:47:03 -0400 Subject: [PATCH 34/62] flow over and update PBE loss Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 116 +++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d253d2b..02df3bc 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -4,6 +4,7 @@ from abc 
import ABC, abstractmethod from gridfm_graphkit.io.registries import LOSS_REGISTRY from torch_scatter import scatter_add +from torch_geometric.utils import to_torch_coo_tensor from gridfm_graphkit.datasets.globals import ( # Bus feature indices @@ -19,6 +20,11 @@ PG_OUT, # Generator feature indices PG_H, + # Edge feature indices + YFF_TT_R, + YFF_TT_I, + YFT_TF_R, + YFT_TF_I, ) @@ -322,3 +328,113 @@ def forward( f"MSE loss {self.dim}": mse_loss.detach(), f"MAE loss {self.dim}": mae_loss.detach(), } + +@LOSS_REGISTRY.register("PBE") +class PBELoss(BaseLoss): + """ + Loss based on the Power Balance Equations. + + Adapted for the heterogeneous graph convention: predictions and targets + are passed as dicts (``{"bus": …, "gen": …}``). Generator active power + is aggregated onto bus nodes via the ``(gen, connected_to, bus)`` edge + index before computing the power balance. + """ + + def __init__(self, loss_args, args): + super(PBELoss, self).__init__() + self.visualization = getattr(loss_args, "visualization", False) + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + ): + pred_bus = pred_dict["bus"] # [N_bus, output_bus_dim] + target_bus = target_dict["bus"] # [N_bus, bus_feat_dim] + num_bus = target_bus.size(0) + + bus_edge_index = edge_index_dict[("bus", "connects", "bus")] + bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")] + mask_bus = mask_dict["bus"] + + # --- Voltage: use prediction where masked, target where known --- + Vm_pred = pred_bus[:, VM_OUT] + Va_pred = pred_bus[:, VA_OUT] + Vm_target = target_bus[:, VM_H] + Va_target = target_bus[:, VA_H] + + mask_Vm = mask_bus[:, VM_H] + mask_Va = mask_bus[:, VA_H] + + V_m = torch.where(mask_Vm, Vm_pred, Vm_target) + V_a = torch.where(mask_Va, Va_pred, Va_target) + + # Complex voltage + V = V_m * torch.exp(1j * V_a) + V_conj = torch.conj(V) + + # --- Admittance matrix from bus-bus edge attrs --- + # Use Yff (diagonal-block) real/imag as the admittance entries + edge_complex = bus_edge_attr[:, YFF_TT_R] + 1j * bus_edge_attr[:, YFF_TT_I] + + Y_bus_sparse = to_torch_coo_tensor( + bus_edge_index, + edge_complex, + size=(num_bus, num_bus), + ) + Y_bus_conj = torch.conj(Y_bus_sparse) + + # Complex power injection: S_inj = diag(V) * conj(Y) * conj(V), + # computed elementwise as V * (conj(Y) @ conj(V)); dense @ sparse matmul + # is unsupported and diag(V) would materialize a dense N x N matrix + S_injection = V * (Y_bus_conj @ V_conj) + + # --- Net power from predictions/targets --- + # Pg: aggregate generator predictions onto buses + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + Pg_per_bus = scatter_add( + pred_dict["gen"].squeeze(-1), + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + + Pd = target_bus[:, PD_H] + Qd = target_bus[:, QD_H] + + # Qg: use prediction if the model predicts it, else use target + if pred_bus.size(1) > QG_OUT: + Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H]) + else: + Qg = target_bus[:, QG_H] + + net_P = Pg_per_bus - Pd + net_Q = Qg - Qd + S_net = net_P + 1j * net_Q + + # --- Loss --- + loss = torch.mean(torch.abs(S_net - S_injection)) + + real_loss = torch.mean( + torch.abs(torch.real(S_net - S_injection)), + ) + imag_loss = torch.mean( + torch.abs(torch.imag(S_net - S_injection)), + ) + + result = { + "loss": loss, + "Power loss in p.u.": loss.detach(), + "Active Power Loss in p.u.": real_loss.detach(), + "Reactive Power Loss in p.u.": imag_loss.detach(), + } + if self.visualization: + result["Nodal Active Power Loss in p.u."] = torch.abs( + torch.real(S_net - S_injection), + ) + result["Nodal Reactive Power Loss in p.u."] =
torch.abs( + torch.imag(S_net - S_injection), + ) + return result From 9bcb0d1f73dede61aeeead8f0b22967a2647aaa4 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:47:47 -0400 Subject: [PATCH 35/62] added sample config Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 94 +++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/config/GRIT_PF_datakit_case14.yaml diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml new file mode 100644 index 0000000..d119e13 --- /dev/null +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -0,0 +1,94 @@ +callbacks: + patience: 100 + tol: 0 +task: + task_name: PowerFlow +data: + baseMVA: 100 + mask_value: 0.0 + normalization: HeteroDataMVANormalizer + networks: + - case14_ieee + scenarios: + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + posenc_RRWP: + enable: false + ksteps: 21 + posenc_RWSE: + enable: true + kernel: + times: 21 +model: + attention_head: 8 + dropout: 0.1 + # edge_dim must match the bus-bus edge feature count after transforms + # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) + edge_dim: 10 + hidden_size: 116 + # input_dim = bus feature count (used by GRIT core FeatureEncoder) + input_dim: 15 + # Hetero adapter head dimensions + input_bus_dim: 15 + input_gen_dim: 6 + output_bus_dim: 2 + output_gen_dim: 1 + num_layers: 10 + type: GRIT + act: relu + encoder: + node_encoder: true + edge_encoder: true + node_encoder_name: RWSE + node_encoder_bn: true + edge_encoder_bn: true + posenc_RWSE: + kernel: + times: 21 + pe_dim: 20 + raw_norm_type: batchnorm + gt: + layer_type: GritTransformer + dim_hidden: 116 # must match hidden_size + layer_norm: false + batch_norm: true + update_e: true + attn_dropout: 0.2 + attn: + clamp: 5. 
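# --- Editor's note: a two-bus numeric sanity check of the power-balance
# residual that PBELoss (previous patch) drives to zero. It uses the standard
# identity S_inj = V * conj(Y_bus @ V); the loss computes the same quantity
# as V * (conj(Y_bus) @ conj(V)). All numbers below are made up for
# illustration only.
import cmath
import torch

Y_bus = torch.tensor([[ 2.0 - 4.0j, -2.0 + 4.0j],
                      [-2.0 + 4.0j,  2.0 - 4.0j]])     # toy 2x2 admittance
V = torch.tensor([1.0 + 0.0j, 0.98 * cmath.exp(-0.05j)])

S_inj = V * torch.conj(Y_bus @ V)                      # complex power into the grid
S_net = torch.tensor([0.1 + 0.05j, -0.1 - 0.05j])      # Pg - Pd + j(Qg - Qd), dummies
pbe_loss = (S_net - S_inj).abs().mean()                # scalar residual, as in PBELoss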
+ act: relu + full_attn: true + edge_enhance: true + O_e: true + norm_e: true + signed_sqrt: true + bn_momentum: 0.1 + bn_no_runner: false +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.09 + - 0.9 + losses: + - PBE + - MaskedGenMSE + - MaskedBusMSE + loss_args: + - {} + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +verbose: true From b68ae5e12d9d92a8495efb960983040d47efce85 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:52:37 -0400 Subject: [PATCH 36/62] update project toml Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 2b6c523..100b875 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,14 @@ dependencies = [ "nbformat>=5.10.4", "networkx>=3.4.2", "numpy>=2.2.6", + "opt-einsum>=3.3.0", "pandas>=2.3.0", "plotly>=6.1.2", "pyyaml>=6.0.2", "torch>=2.7.1,<2.9", "torch-geometric>=2.6.1", + "torch-scatter>=2.1.2", + "torch-sparse>=0.6.18", "torchaudio>=2.7.1", "torchvision>=0.22.1", "lightning", From 52942fa6aef4129ab9030b13fb685df3a5273537 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:28:02 -0400 Subject: [PATCH 37/62] simplify configuration file Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 5 ++--- gridfm_graphkit/models/grit_transformer.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index d119e13..d1ecdca 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -45,13 +45,12 @@ model: node_encoder_bn: true edge_encoder_bn: true posenc_RWSE: - kernel: - times: 21 + # kernel.times is synced automatically from data.posenc_RWSE.kernel.times pe_dim: 20 raw_norm_type: batchnorm gt: layer_type: GritTransformer - dim_hidden: 116 # must match hidden_size + # dim_hidden is synced automatically from model.hidden_size layer_norm: false batch_norm: true update_e: true diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 17b1539..aab8b93 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -247,6 +247,24 @@ def __init__(self, args): if not hasattr(args.model, "output_dim"): args.model.output_dim = output_bus_dim + # Sync PE kernel.times from data config into model encoder config so + # users only need to specify it once (under data.posenc_RWSE). + if ( + hasattr(args.data, "posenc_RWSE") + and args.data.posenc_RWSE.enable + and hasattr(args.model, "encoder") + and hasattr(args.model.encoder, "posenc_RWSE") + ): + from gridfm_graphkit.io.param_handler import NestedNamespace + enc_rwse = args.model.encoder.posenc_RWSE + if not hasattr(enc_rwse, "kernel"): + enc_rwse.kernel = NestedNamespace() + enc_rwse.kernel.times = args.data.posenc_RWSE.kernel.times + + # Sync gt.dim_hidden from model.hidden_size so it is specified once. 
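# --- Editor's note: the sync below is a single-source-of-truth pattern: the
# user sets model.hidden_size once, and the derived gt.dim_hidden is copied
# in at construction time. A minimal sketch with types.SimpleNamespace
# standing in for this codebase's NestedNamespace:
from types import SimpleNamespace

model_cfg = SimpleNamespace(
    hidden_size=116,
    gt=SimpleNamespace(layer_type="GritTransformer"),
)
model_cfg.gt.dim_hidden = model_cfg.hidden_size   # derived, never set by hand
assert model_cfg.gt.dim_hidden == 116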
+ if hasattr(args.model, "gt"): + args.model.gt.dim_hidden = args.model.hidden_size + # The original homogeneous GRIT # (encoder + optional PE encoders + transformer layers + GraphHead) self.grit = GritTransformer(args) From eba33e52aba4204c07d8dbc3b412e6d500f5e528 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:32:41 -0400 Subject: [PATCH 38/62] flow over time benchmarking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 447 +++++++++++++++++++++++++++ scripts/run_benchmark.sh | 39 +++ 2 files changed, 486 insertions(+) create mode 100644 scripts/benchmark_model_inference.py create mode 100755 scripts/run_benchmark.sh diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py new file mode 100644 index 0000000..9a80d14 --- /dev/null +++ b/scripts/benchmark_model_inference.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +A unified script for benchmarking and limited custom profiling. Benchmarking columns in the output csv are [batch_size,avg_time_per_sample_ms]. + +Example usagef (edge count is 2*E (branch count)): + +###################################### + +CONF_PATH=../examples/config +OUT_DIR=../scripts +mkdir $OUT_DIR + +python benchmark_model_inference.py --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case57_ieee_base.yaml --num_nodes 57 --num_edges 160 --num_gens 7 --iterations 20 --output_csv $OUT_DIR/case57.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case500_ieee_base.yaml --num_nodes 500 --num_edges 1466 --num_gens 224 --iterations 20 --output_csv $OUT_DIR/case500.csv || true +python benchmark_model_inference.py --config $CONF_PATH/case2000_ieee_base.yaml --num_nodes 2000 --num_edges 7278 --num_gens 384 --iterations 20 --output_csv $OUT_DIR/case2000.csv || true + +###################################### + +Author(s): Mangaliso M. - mngomezulum@ibm.com + Matteo M. 
- Not Available +""" + +import os +import time +import csv +import yaml +import torch +import argparse +import platform +from datetime import datetime +from torch_geometric.loader import DataLoader +from torch_geometric.data import HeteroData +from gridfm_graphkit.io.param_handler import NestedNamespace, load_model + +# Optional: tqdm (imported but not required for core flow) +try: + from tqdm import tqdm # noqa: F401 +except Exception: + pass + +# Compilation (kept from original) +import torch._dynamo as dynamo +dynamo.config.suppress_errors = False + +# ---------------------------- +# Argument Parsing +# ---------------------------- +parser = argparse.ArgumentParser(description="Benchmark GNS_final Heterogeneous Model with profiling CSV") +parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model") +parser.add_argument("--num_nodes", type=int, required=True) +parser.add_argument("--num_gens", type=int, required=True) +parser.add_argument("--num_edges", type=int, required=True) +parser.add_argument("--output_csv", type=str, required=True) +parser.add_argument("--iterations", type=int, default=20) +parser.add_argument("--num_workers", type=int, default=0, help="DataLoader num_workers") +parser.add_argument("--pin_memory", action="store_true", help="Enable pin_memory in DataLoader when CUDA is available") +args = parser.parse_args() + +# --- Custom logging (ensure directory exists) +import logging +os.makedirs('logs', exist_ok=True) +logger = logging.getLogger('ibm_benchmark_logger') +logger.setLevel(logging.DEBUG) +logger.propagate = False +file_handler = logging.FileHandler('logs/ibm_bench_logs.log', mode='a') # 'a' for append, 'w' to overwrite +file_handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +file_handler.setFormatter(formatter) +if not logger.handlers: + logger.addHandler(file_handler) + +# ---------------------------- +# Load Model +# ---------------------------- +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +with open(args.config, "r") as f: + base_config = yaml.safe_load(f) + +config_args = NestedNamespace(**base_config) +model = load_model(config_args).to(device).eval() + +# ---------------------------- +# Parameters +# ---------------------------- +N_BUS = args.num_nodes +N_GEN = args.num_gens +E = args.num_edges +BUS_FEATS = config_args.model.input_bus_dim +GEN_FEATS = config_args.model.input_gen_dim +EDGE_FEATS = config_args.model.edge_dim + +# Keep original batch sizes list +batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] +iterations = args.iterations + +# ---------------------------- +# Helpers +# ---------------------------- +def now_ms() -> float: + return time.perf_counter() * 1000.0 + +def maybe_cuda_sync(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + +def get_env_info(): + # CPU name detection + cpu_name = None + try: + cpu_name = platform.processor() or None + if not cpu_name and os.path.exists("/proc/cpuinfo"): + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_name = line.strip().split(":", 1)[1].strip() + break + if not cpu_name: + cpu_name = platform.uname().machine + except Exception: + cpu_name = "unknown" + + # GPU names and device info + if torch.cuda.is_available(): + try: + gpu_names_list = 
[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + gpu_names = "; ".join(gpu_names_list) + except Exception: + gpu_names = "cuda_available_but_name_unreadable" + device_type = "cuda" + device_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "cuda" + cuda_version_in_torch = torch.version.cuda + cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None + else: + # Apple Metal backend? + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + device_type = "mps" + device_name = "Apple MPS" + gpu_names = "mps" + cuda_version_in_torch = None + cudnn_version = None + else: + device_type = "cpu" + device_name = "cpu" + gpu_names = "none" + cuda_version_in_torch = None + cudnn_version = None + + info = { + "device_type": device_type, + "device_name": device_name, + "gpu_names": gpu_names, + "cpu_name": cpu_name, + "torch_version": torch.__version__, + "cuda_version_in_torch": cuda_version_in_torch, + "cudnn_version": cudnn_version, + "python_version": platform.python_version(), + } + return info + +# ---------------------------- +# Generate Synthetic Hetero Graph +# ---------------------------- +def generate_hetero_graph(): + """ + Generates a dummy heterogeneous power network graph for benchmarking. + + Returns: + data (HeteroData): single self-contained heterogeneous graph with: + - data["bus"].x, data["gen"].x + - edge_index & edge_attr for all relations + - mask_dict inside data.mask_dict + """ + data = HeteroData() + + # Node features + data["bus"].x = torch.randn(N_BUS, BUS_FEATS) + data["gen"].x = torch.randn(N_GEN, GEN_FEATS) + + # Edges: Bus–Bus + src = torch.randint(0, N_BUS, (E,)) + dst = torch.randint(0, N_BUS, (E,)) + data["bus", "connects", "bus"].edge_index = torch.stack([src, dst], dim=0) + data["bus", "connects", "bus"].edge_attr = torch.randn(E, EDGE_FEATS) + + # Edges: Gen–Bus & Bus–Gen + gen_to_bus = torch.randint(0, N_BUS, (N_GEN,)) + + # Gen → Bus + data["gen", "connected_to", "bus"].edge_index = torch.stack( + [torch.arange(N_GEN), gen_to_bus], dim=0 + ) + + # Bus → Gen + data["bus", "connected_to", "gen"].edge_index = torch.stack( + [gen_to_bus, torch.arange(N_GEN)], dim=0 + ) + + # No edge features for these + data["gen", "connected_to", "bus"].edge_attr = None + data["bus", "connected_to", "gen"].edge_attr = None + + # Dummy masks (all True) + mask_bus = torch.ones_like(data["bus"].x, dtype=torch.bool) + mask_gen = torch.ones_like(data["gen"].x, dtype=torch.bool) + bus_types = torch.randint(0, 3, (N_BUS,)) + mask_branch = torch.ones_like(data["bus", "connects", "bus"].edge_attr, dtype=torch.bool) + + mask_PQ = bus_types == 0 + mask_PV = bus_types == 1 + mask_REF = bus_types == 2 + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + "branch": mask_branch + } + return data + +# ---------------------------- +# Benchmark Function +# ---------------------------- +def benchmark(): + # Environment/context info (constant per run) + env = get_env_info() + timestamp = datetime.now().isoformat(timespec='seconds') + + # Measure synthetic graph creation + t0 = now_ms() + data = generate_hetero_graph() + t1 = now_ms() + data_gen_time_ms = t1 - t0 + + # Move the base graph to device (preserve original behavior) + maybe_cuda_sync() + t2 = now_ms() + data = data.to(device) + maybe_cuda_sync() + t3 = now_ms() + graph_to_device_time_ms = t3 - t2 + + batch_sizes_used = [] + times = [] + + header = [ + # Keep original 
first two columns + "batch_size", + "avg_time_per_sample_ms", + + # Execution config + "num_iters", + "total_samples", + + # Data/IO timing + "data_gen_time_ms", + "graph_to_device_time_ms", + "clone_list_time_ms", + "dataloader_create_time_ms", + "dataloader_first_iter_time_ms", + "batch_to_device_time_ms", + + # Model timing + "warmup_time_ms", + "iter_total_wall_time_ms", + "iter_gpu_time_ms", + "gpu_idle_time_ms", + "gpu_busy_ratio", + "samples_per_sec_wall", + "samples_per_sec_gpu", + "timing_source", # "cuda_event" or "wall_clock" + + # Memory + "max_cuda_mem_alloc_bytes", + "max_cuda_mem_reserved_bytes", + + # Graph & model context + "n_bus", "n_gen", "n_edges", + "bus_feats", "gen_feats", "edge_feats", + + # Runtime context + "device_type", "device_name", + "torch_version", "cuda_version_in_torch", "cudnn_version", + "python_version", + "cpu_name", # NEW + "gpu_names", # NEW + "timestamp_iso", + "num_workers", + "pin_memory", + ] + + with open(args.output_csv, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(header) + + for batch_size in batch_sizes: + # Build list of graphs (on device, preserving original flow) + maybe_cuda_sync() + t_clone_start = now_ms() + data_list = [data.clone() for _ in range(batch_size)] + maybe_cuda_sync() + t_clone_end = now_ms() + clone_list_time_ms = t_clone_end - t_clone_start + + # Create DataLoader + pin_mem = args.pin_memory and torch.cuda.is_available() + persistent = args.num_workers > 0 + t_dl_create_start = now_ms() + loader = DataLoader( + data_list, + batch_size=batch_size, + num_workers=args.num_workers, + pin_memory=pin_mem, + persistent_workers=persistent, + ) + t_dl_create_end = now_ms() + dataloader_create_time_ms = t_dl_create_end - t_dl_create_start + + # Fetch first batch (collate) + t_iter_start = now_ms() + batch = next(iter(loader)) + t_iter_end = now_ms() + dataloader_first_iter_time_ms = t_iter_end - t_iter_start + + # Ensure batch on device (likely ~0 if items already on device) + maybe_cuda_sync() + t_b2d_start = now_ms() + batch = batch.to(device, non_blocking=True) if torch.cuda.is_available() else batch.to(device) + maybe_cuda_sync() + t_b2d_end = now_ms() + batch_to_device_time_ms = t_b2d_end - t_b2d_start + + test_model = model + + # Warmup (excluded from main timing) + maybe_cuda_sync() + t_warmup_start = now_ms() + with torch.no_grad(): + for _ in range(5): + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + maybe_cuda_sync() + t_warmup_end = now_ms() + warmup_time_ms = t_warmup_end - t_warmup_start + + num_iters = iterations + total_samples = batch_size * num_iters + + # Reset CUDA memory stats and set up timing + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(device) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Iteration timing + maybe_cuda_sync() + wall_start = now_ms() + with torch.no_grad(): + if torch.cuda.is_available(): + start_event.record() + for _ in range(num_iters): + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if torch.cuda.is_available(): + end_event.record() + maybe_cuda_sync() + wall_end = now_ms() + + iter_total_wall_time_ms = wall_end - wall_start + + if torch.cuda.is_available(): + iter_gpu_time_ms = float(start_event.elapsed_time(end_event)) # ms + timing_source = "cuda_event" + avg_time_per_sample_ms = iter_gpu_time_ms / total_samples + gpu_idle_time_ms = max(iter_total_wall_time_ms - 
iter_gpu_time_ms, 0.0) + gpu_busy_ratio = (iter_gpu_time_ms / iter_total_wall_time_ms) if iter_total_wall_time_ms > 0 else None + max_cuda_mem_alloc_bytes = int(torch.cuda.max_memory_allocated(device)) + max_cuda_mem_reserved_bytes = int(torch.cuda.max_memory_reserved(device)) + samples_per_sec_gpu = (total_samples / (iter_gpu_time_ms / 1000.0)) if iter_gpu_time_ms > 0 else None + else: + iter_gpu_time_ms = None + timing_source = "wall_clock" + avg_time_per_sample_ms = iter_total_wall_time_ms / total_samples + gpu_idle_time_ms = None + gpu_busy_ratio = None + max_cuda_mem_alloc_bytes = None + max_cuda_mem_reserved_bytes = None + samples_per_sec_gpu = None + + samples_per_sec_wall = (total_samples / (iter_total_wall_time_ms / 1000.0)) if iter_total_wall_time_ms > 0 else None + + # Prepare row + row = [ + batch_size, + avg_time_per_sample_ms, + + num_iters, + total_samples, + + data_gen_time_ms, + graph_to_device_time_ms, + clone_list_time_ms, + dataloader_create_time_ms, + dataloader_first_iter_time_ms, + batch_to_device_time_ms, + + warmup_time_ms, + iter_total_wall_time_ms, + iter_gpu_time_ms, + gpu_idle_time_ms, + gpu_busy_ratio, + samples_per_sec_wall, + samples_per_sec_gpu, + timing_source, + + max_cuda_mem_alloc_bytes, + max_cuda_mem_reserved_bytes, + + N_BUS, N_GEN, E, + BUS_FEATS, GEN_FEATS, EDGE_FEATS, + + env["device_type"], env["device_name"], + env["torch_version"], env["cuda_version_in_torch"], env["cudnn_version"], + env["python_version"], + env["cpu_name"], + env["gpu_names"], + timestamp, + args.num_workers, + bool(pin_mem), + ] + + writer.writerow(row) + csvfile.flush() + batch_sizes_used.append(batch_size) + times.append(avg_time_per_sample_ms) + + return batch_sizes_used, times + + +if __name__ == "__main__": + print(f"Starting benchmark for {os.path.basename(args.output_csv)} ..") + benchmark() + print(f"Finished benchmarking for {os.path.basename(args.output_csv)}\n ...") diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh new file mode 100755 index 0000000..52ae981 --- /dev/null +++ b/scripts/run_benchmark.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +set +e # Do NOT exit on error + +CONFIGS=( + "gridfm01" + "gridfm02" +) + +CONFIG_PATHS=( + "../examples/config/gridFMv0.1_pretraining.yaml" + "../examples/config/gridFMv0.2_pretraining.yaml" +) + +GRAPH_SIZES=( + "30 110" + "300 1120" + "2000 9276" + "3022 11390" + "9241 41337" + "30000 100784" +) + +OUTPUT_DIR="benchmark_results" +mkdir -p $OUTPUT_DIR +for i in "${!CONFIGS[@]}"; do + config_name="${CONFIGS[$i]}" + config_path="${CONFIG_PATHS[$i]}" + for size in "${GRAPH_SIZES[@]}"; do + read -r nodes edges <<< "$size" + output_file="${OUTPUT_DIR}/${config_name}_${nodes}nodes_${edges}edges.csv" + echo "Running benchmark for $config_name with $nodes nodes and $edges edges..." 
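# --- Editor's note: the GPU timing pattern used in benchmark() above, as a
# standalone Python sketch. CUDA events bracket the un-synchronized loop and
# elapsed_time() reads GPU milliseconds after a final synchronize; the wall
# clock around the same span gives the idle/busy split. Requires a CUDA device.
import time
import torch

def timed_gpu_loop(fn, iters=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    wall0 = time.perf_counter()
    start.record()
    for _ in range(iters):
        fn()                                   # e.g. a no-grad model forward
    end.record()
    torch.cuda.synchronize()
    wall_ms = (time.perf_counter() - wall0) * 1000.0
    gpu_ms = start.elapsed_time(end)           # GPU-side milliseconds
    return gpu_ms, wall_ms                     # busy ratio = gpu_ms / wall_ms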
+ python benchmark_model_inference.py \ + --config "$config_path" \ + --output_csv "$output_file" \ + --num_nodes "$nodes" \ + --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes" + done +done \ No newline at end of file From 9dd35d6fcab24fa3b9b567a4b09971d181521730 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:20:17 -0400 Subject: [PATCH 39/62] add baseline grit support Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 112 +++++++++++++++++++++++---- 1 file changed, 98 insertions(+), 14 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index 9a80d14..f98e686 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -4,7 +4,11 @@ """ A unified script for benchmarking and limited custom profiling. Benchmarking columns in the output csv are [batch_size,avg_time_per_sample_ms]. -Example usagef (edge count is 2*E (branch count)): +Supports two model types via --model flag: + - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes) + - "grit": GritTransformer with homogeneous Data (single node type) + +Example usage — Heterogeneous GNS (edge count is 2*E (branch count)): ###################################### @@ -12,11 +16,17 @@ OUT_DIR=../scripts mkdir $OUT_DIR -python benchmark_model_inference.py --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case57_ieee_base.yaml --num_nodes 57 --num_edges 160 --num_gens 7 --iterations 20 --output_csv $OUT_DIR/case57.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case500_ieee_base.yaml --num_nodes 500 --num_edges 1466 --num_gens 224 --iterations 20 --output_csv $OUT_DIR/case500.csv || true -python benchmark_model_inference.py --config $CONF_PATH/case2000_ieee_base.yaml --num_nodes 2000 --num_edges 7278 --num_gens 384 --iterations 20 --output_csv $OUT_DIR/case2000.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true + +###################################### + +Example usage — GRIT (homogeneous, --num_gens is ignored): + +###################################### + +python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 30 --num_edges 82 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 118 --num_edges 372 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true ###################################### @@ -33,7 +43,7 @@ import platform from datetime import datetime from torch_geometric.loader import DataLoader -from torch_geometric.data import HeteroData +from torch_geometric.data import Data, HeteroData from 
gridfm_graphkit.io.param_handler import NestedNamespace, load_model # Optional: tqdm (imported but not required for core flow) @@ -49,10 +59,13 @@ # ---------------------------- # Argument Parsing # ---------------------------- -parser = argparse.ArgumentParser(description="Benchmark GNS_final Heterogeneous Model with profiling CSV") +parser = argparse.ArgumentParser(description="Benchmark GNN Model inference with profiling CSV") +parser.add_argument("--model", type=str, choices=["hetero", "grit"], default="hetero", + help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer") parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model") parser.add_argument("--num_nodes", type=int, required=True) -parser.add_argument("--num_gens", type=int, required=True) +parser.add_argument("--num_gens", type=int, default=0, + help="Number of generator nodes (required for hetero, ignored for grit)") parser.add_argument("--num_edges", type=int, required=True) parser.add_argument("--output_csv", type=str, required=True) parser.add_argument("--iterations", type=int, default=20) @@ -87,13 +100,30 @@ # ---------------------------- # Parameters # ---------------------------- +MODEL_TYPE = args.model N_BUS = args.num_nodes N_GEN = args.num_gens E = args.num_edges -BUS_FEATS = config_args.model.input_bus_dim -GEN_FEATS = config_args.model.input_gen_dim EDGE_FEATS = config_args.model.edge_dim +if MODEL_TYPE == "hetero": + BUS_FEATS = config_args.model.input_bus_dim + GEN_FEATS = config_args.model.input_gen_dim + NODE_FEATS = None # not used for hetero +else: + # GRIT homogeneous model + NODE_FEATS = config_args.model.input_dim + OUTPUT_DIM = config_args.model.output_dim + MASK_DIM = getattr(config_args.data, "mask_dim", 6) + # Positional encoding config + RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False + RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 + RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ + and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") + RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + BUS_FEATS = NODE_FEATS # alias for CSV output compatibility + GEN_FEATS = 0 + # Keep original batch sizes list batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] iterations = args.iterations @@ -224,6 +254,51 @@ def generate_hetero_graph(): } return data + +# ---------------------------- +# Generate Synthetic Homogeneous Graph (GRIT) +# ---------------------------- +def generate_homo_graph(): + """ + Generates a dummy homogeneous power network graph for GRIT benchmarking. 
+ + Returns: + data (Data): single self-contained homogeneous graph with: + - data.x: node features [N_BUS, NODE_FEATS] + - data.y: target labels [N_BUS, OUTPUT_DIM] + - data.edge_index: [2, E] + - data.edge_attr: [E, EDGE_FEATS] + - data.pestat_RWSE (if RWSE enabled): [N_BUS, RWSE_TIMES] + - data.rrwp, rrwp_index, rrwp_val (if RRWP enabled) + """ + data = Data() + + # Node features: same layout as powergrid_dataset (Pd, Qd, Pg, Qg, Vm, Va, PQ, PV, REF) + data.x = torch.randn(N_BUS, NODE_FEATS) + data.y = data.x[:, :OUTPUT_DIM].clone() + + # Edges + src = torch.randint(0, N_BUS, (E,)) + dst = torch.randint(0, N_BUS, (E,)) + data.edge_index = torch.stack([src, dst], dim=0) + data.edge_attr = torch.randn(E, EDGE_FEATS) + + # RWSE positional encoding (diagonal of random-walk matrix powers) + if RWSE_ENABLED: + data.pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() + + # RRWP positional / structural encoding + if RRWP_ENABLED: + data.rrwp = torch.randn(N_BUS, RRWP_KSTEPS) + # Sparse RRWP for edges: include existing edges + self-loops + self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1) + rrwp_idx = torch.cat([data.edge_index, self_loops], dim=1) + rrwp_nnz = rrwp_idx.size(1) + data.rrwp_index = rrwp_idx + data.rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) + + return data + # ---------------------------- # Benchmark Function # ---------------------------- @@ -234,7 +309,10 @@ def benchmark(): # Measure synthetic graph creation t0 = now_ms() - data = generate_hetero_graph() + if MODEL_TYPE == "hetero": + data = generate_hetero_graph() + else: + data = generate_homo_graph() t1 = now_ms() data_gen_time_ms = t1 - t0 @@ -343,7 +421,10 @@ def benchmark(): t_warmup_start = now_ms() with torch.no_grad(): for _ in range(5): - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if MODEL_TYPE == "hetero": + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + else: + _ = test_model(batch.clone()) maybe_cuda_sync() t_warmup_end = now_ms() warmup_time_ms = t_warmup_end - t_warmup_start @@ -364,7 +445,10 @@ def benchmark(): if torch.cuda.is_available(): start_event.record() for _ in range(num_iters): - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + if MODEL_TYPE == "hetero": + _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) + else: + _ = test_model(batch.clone()) if torch.cuda.is_available(): end_event.record() maybe_cuda_sync() From 91047ccac0b8c8b86c2b84b44e0f2b51810440ca Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:25:31 -0400 Subject: [PATCH 40/62] update benchmarking for new grit format Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/grit_pretraining.yaml | 97 ------------------------ scripts/benchmark_model_inference.py | 104 ++++++++++++-------------- scripts/run_benchmark.sh | 1 + 3 files changed, 49 insertions(+), 153 deletions(-) delete mode 100644 examples/config/grit_pretraining.yaml diff --git a/examples/config/grit_pretraining.yaml b/examples/config/grit_pretraining.yaml deleted file mode 100644 index bf56dff..0000000 --- a/examples/config/grit_pretraining.yaml +++ /dev/null @@ -1,97 +0,0 @@ -callbacks: - patience: 100 - tol: 0 -data: - baseMVA: 100 - learn_mask: false - mask_dim: 6 - mask_ratio: 0.5 - mask_type: rnd - mask_value: -1.0 - networks: - # - Texas2k_case1_2016summerpeak - - 
case24_ieee_rts - - case118_ieee - - case300_ieee - - case89_pegase - - case240_pserc - normalization: baseMVAnorm - scenarios: - # - 5000 - - 5000 - - 5000 - - 3000 - - 5000 - - 5000 - test_ratio: 0.1 - val_ratio: 0.1 - workers: 4 - posenc_RRWP: - enable: False - ksteps: 21 - add_identity: True - add_node_attr: False - add_inverse: False - posenc_RWSE: - enable: True - kernel: - times: 21 # TODO unify with model -model: - attention_head: 8 - dropout: 0.1 - edge_dim: 2 - hidden_size: 116 # `gt.dim_hidden` must match `gnn.dim_inner` - input_dim: 9 - num_layers: 10 - output_dim: 6 - pe_dim: 20 - type: GRIT - act: relu - encoder: - node_encoder: True - edge_encoder: True - node_encoder_name: RWSE - node_encoder_bn: True - edge_encoder_bn: True - posenc_RWSE: - kernel: - times: 21 - pe_dim: 20 # TODO unify with model.pe_dim - raw_norm_type: batchnorm - gt: - layer_type: GritTransformer - dim_hidden: 116 # `gt.dim_hidden` must match `gnn.dim_inner` - layer_norm: False - batch_norm: True - update_e: True - attn_dropout: 0.2 - attn: - clamp: 5. - act: 'relu' - full_attn: True - edge_enhance: True - O_e: True - norm_e: True - signed_sqrt: True - bn_momentum: 0.1 - bn_no_runner: False - -optimizer: - beta1: 0.9 - beta2: 0.999 - learning_rate: 0.0001 - lr_decay: 0.7 - lr_patience: 10 -seed: 0 -training: - batch_size: 8 - epochs: 500 - loss_weights: - - 0.01 - - 0.99 - losses: - - MaskedMSE - - PBE - accelerator: auto - devices: auto - strategy: auto diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f98e686..f8c7b76 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -6,7 +6,7 @@ Supports two model types via --model flag: - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes) - - "grit": GritTransformer with homogeneous Data (single node type) + - "grit": GritHeteroAdapter with HeteroData (bus + gen nodes, optional PE attrs) Example usage — Heterogeneous GNS (edge count is 2*E (branch count)): @@ -21,12 +21,12 @@ ###################################### -Example usage — GRIT (homogeneous, --num_gens is ignored): +Example usage — GRIT (HeteroData with PE, --num_gens required): ###################################### -python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 30 --num_edges 82 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true -python benchmark_model_inference.py --model grit --config $CONF_PATH/grit_pretraining.yaml --num_nodes 118 --num_edges 372 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true +python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true ###################################### @@ -43,7 +43,7 @@ import platform from datetime import datetime from torch_geometric.loader import DataLoader -from torch_geometric.data import Data, HeteroData +from torch_geometric.data import HeteroData from gridfm_graphkit.io.param_handler import NestedNamespace, load_model # Optional: tqdm (imported but not required for core flow) @@ -102,27 +102,31 @@ # ---------------------------- MODEL_TYPE = args.model N_BUS = args.num_nodes -N_GEN = args.num_gens E = args.num_edges 
-EDGE_FEATS = config_args.model.edge_dim -if MODEL_TYPE == "hetero": - BUS_FEATS = config_args.model.input_bus_dim - GEN_FEATS = config_args.model.input_gen_dim - NODE_FEATS = None # not used for hetero -else: - # GRIT homogeneous model - NODE_FEATS = config_args.model.input_dim - OUTPUT_DIM = config_args.model.output_dim - MASK_DIM = getattr(config_args.data, "mask_dim", 6) - # Positional encoding config +# Default num_gens when not provided (shell script omits --num_gens) +N_GEN = args.num_gens if args.num_gens > 0 else max(1, N_BUS // 5) + +EDGE_FEATS = getattr(config_args.model, "edge_dim", 10) + +# Both model types use HeteroData with bus + gen nodes. +# Fall back to input_dim / defaults for configs that lack the hetero keys. +BUS_FEATS = getattr(config_args.model, "input_bus_dim", + getattr(config_args.model, "input_dim", 15)) +GEN_FEATS = getattr(config_args.model, "input_gen_dim", 6) + +if MODEL_TYPE == "grit": + # Positional encoding config (only GRIT uses these) RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 - BUS_FEATS = NODE_FEATS # alias for CSV output compatibility - GEN_FEATS = 0 +else: + RRWP_ENABLED = False + RRWP_KSTEPS = 0 + RWSE_ENABLED = False + RWSE_TIMES = 0 # Keep original batch sizes list batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] @@ -211,6 +215,10 @@ def generate_hetero_graph(): data["bus"].x = torch.randn(N_BUS, BUS_FEATS) data["gen"].x = torch.randn(N_GEN, GEN_FEATS) + # Dummy targets + data["bus"].y = torch.randn(N_BUS, BUS_FEATS) + data["gen"].y = torch.randn(N_GEN, GEN_FEATS) + # Edges: Bus–Bus src = torch.randint(0, N_BUS, (E,)) dst = torch.randint(0, N_BUS, (E,)) @@ -258,44 +266,34 @@ def generate_hetero_graph(): # ---------------------------- # Generate Synthetic Homogeneous Graph (GRIT) # ---------------------------- -def generate_homo_graph(): +def generate_grit_graph(): """ - Generates a dummy homogeneous power network graph for GRIT benchmarking. + Generates a dummy heterogeneous graph for GRIT benchmarking. + + GritHeteroAdapter expects a HeteroData batch, so we generate the same + structure as generate_hetero_graph() but also attach PE attributes on + the bus node store when RWSE / RRWP are enabled. Returns: - data (Data): single self-contained homogeneous graph with: - - data.x: node features [N_BUS, NODE_FEATS] - - data.y: target labels [N_BUS, OUTPUT_DIM] - - data.edge_index: [2, E] - - data.edge_attr: [E, EDGE_FEATS] - - data.pestat_RWSE (if RWSE enabled): [N_BUS, RWSE_TIMES] - - data.rrwp, rrwp_index, rrwp_val (if RRWP enabled) + data (HeteroData): heterogeneous graph with bus & gen nodes, + plus optional PE attributes on data["bus"]. 
""" - data = Data() - - # Node features: same layout as powergrid_dataset (Pd, Qd, Pg, Qg, Vm, Va, PQ, PV, REF) - data.x = torch.randn(N_BUS, NODE_FEATS) - data.y = data.x[:, :OUTPUT_DIM].clone() - - # Edges - src = torch.randint(0, N_BUS, (E,)) - dst = torch.randint(0, N_BUS, (E,)) - data.edge_index = torch.stack([src, dst], dim=0) - data.edge_attr = torch.randn(E, EDGE_FEATS) + data = generate_hetero_graph() - # RWSE positional encoding (diagonal of random-walk matrix powers) + # RWSE positional encoding on bus nodes if RWSE_ENABLED: - data.pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() + data["bus"].pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs() - # RRWP positional / structural encoding + # RRWP positional / structural encoding on bus nodes if RRWP_ENABLED: - data.rrwp = torch.randn(N_BUS, RRWP_KSTEPS) - # Sparse RRWP for edges: include existing edges + self-loops + data["bus"].rrwp = torch.randn(N_BUS, RRWP_KSTEPS) + # Sparse RRWP for edges: include existing bus-bus edges + self-loops + bb_ei = data["bus", "connects", "bus"].edge_index self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1) - rrwp_idx = torch.cat([data.edge_index, self_loops], dim=1) + rrwp_idx = torch.cat([bb_ei, self_loops], dim=1) rrwp_nnz = rrwp_idx.size(1) - data.rrwp_index = rrwp_idx - data.rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) + data["bus"].rrwp_index = rrwp_idx + data["bus"].rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS) return data @@ -312,7 +310,7 @@ def benchmark(): if MODEL_TYPE == "hetero": data = generate_hetero_graph() else: - data = generate_homo_graph() + data = generate_grit_graph() t1 = now_ms() data_gen_time_ms = t1 - t0 @@ -421,10 +419,7 @@ def benchmark(): t_warmup_start = now_ms() with torch.no_grad(): for _ in range(5): - if MODEL_TYPE == "hetero": - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) - else: - _ = test_model(batch.clone()) + _ = test_model(batch.clone()) maybe_cuda_sync() t_warmup_end = now_ms() warmup_time_ms = t_warmup_end - t_warmup_start @@ -445,10 +440,7 @@ def benchmark(): if torch.cuda.is_available(): start_event.record() for _ in range(num_iters): - if MODEL_TYPE == "hetero": - _ = test_model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict, batch.mask_dict) - else: - _ = test_model(batch.clone()) + _ = test_model(batch.clone()) if torch.cuda.is_available(): end_event.record() maybe_cuda_sync() diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 52ae981..6483b81 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -31,6 +31,7 @@ for i in "${!CONFIGS[@]}"; do output_file="${OUTPUT_DIR}/${config_name}_${nodes}nodes_${edges}edges.csv" echo "Running benchmark for $config_name with $nodes nodes and $edges edges..." 
python benchmark_model_inference.py \ + --model "grit" \ --config "$config_path" \ --output_csv "$output_file" \ --num_nodes "$nodes" \ --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes" done done \ No newline at end of file From deaf640249b74e4b927cd068b7fcfcf7c81f2ce0 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:40:01 -0400 Subject: [PATCH 41/62] cleanup Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/__init__.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index b922274..64d9727 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,10 +1,4 @@ -<<<<<<< HEAD -from gridfm_graphkit.models.gps_transformer import GPSTransformer -from gridfm_graphkit.models.gnn_transformer import GNN_TransformerConv -from gridfm_graphkit.models.grit_transformer import GritTransformer -__all__ = ["GPSTransformer", "GNN_TransformerConv", "GRIT"] -======= from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, @@ -18,4 +12,4 @@ "PhysicsDecoderPF", "PhysicsDecoderSE", ] ->>>>>>> opensource/main + From b620847fae307c3308d0d96f9261c4db5b146473 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:33:31 -0400 Subject: [PATCH 42/62] finalize connection of model Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index 64d9727..b30680e 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,5 +1,6 @@ from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, PhysicsDecoderPF, @@ -8,6 +9,7 @@ __all__ = [ "GNS_heterogeneous", + "GritHeteroAdapter", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", From 964cd5a8577598bfb6c01752572ce664612c3f5b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 10:21:24 -0400 Subject: [PATCH 43/62] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/compute_ac_dc_metrics.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py index 8dcfc8c..3d8118c 100644 --- a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py +++ b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py @@ -4,10 +4,6 @@ import os import numpy as np import pandas as pd -from gridfm_datakit.utils.power_balance import ( - compute_branch_powers_vectorized, - compute_bus_balance, -) N_SCENARIO_PER_PARTITION = 200 NUM_PROCESSES = 64 @@ -132,6 +128,11 @@ def compute_ac_dc_metrics( bus_df, branch_df, runtime_df = _load_test_data(data_dir, test_ids) + from gridfm_datakit.utils.power_balance import ( + compute_branch_powers_vectorized, + compute_bus_balance, + ) + # ========================= # AC residuals # ========================= From 028d7c56d0253a059d0e4bfc674faf4f1ef280cb Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026
13:41:06 -0400 Subject: [PATCH 44/62] flow over random masking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 3 +- gridfm_graphkit/datasets/masking.py | 54 +++++++++++++++++++++ gridfm_graphkit/datasets/task_transforms.py | 17 ++++++- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index d1ecdca..7f89c56 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -5,7 +5,8 @@ task: task_name: PowerFlow data: baseMVA: 100 - mask_value: 0.0 + mask_type: rnd # or determinstic + mask_ratio: 0.5 # for random masking only normalization: HeteroDataMVANormalizer networks: - case14_ieee diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index b615e0c..6b9d3e4 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -33,6 +33,60 @@ from torch_geometric.nn import MessagePassing +class AddRandomHeteroMask(BaseTransform): + """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. + + Each selected feature dimension is independently masked per node/edge with + probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen + features: PG. Masked branch features: P_E, Q_E. + + The output ``data.mask_dict`` has the same structure as the deterministic + PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``, + ``PBELoss``, etc.) work without modification. + """ + + def __init__(self, mask_ratio=0.5): + super().__init__() + self.mask_ratio = mask_ratio + + def forward(self, data): + bus_x = data.x_dict["bus"] + gen_x = data.x_dict["gen"] + + # Bus type indicators (needed by losses and test metrics) + mask_PQ = bus_x[:, PQ_H] == 1 + mask_PV = bus_x[:, PV_H] == 1 + mask_REF = bus_x[:, REF_H] == 1 + + # Random bus mask on variable features the model reconstructs + mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) + n_bus = bus_x.size(0) + for feat_idx in (VM_H, VA_H, QG_H): + mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio + + # Random gen mask on PG + mask_gen = torch.zeros_like(gen_x, dtype=torch.bool) + mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio + + # Random branch mask on flow features + branch_attr = data.edge_attr_dict[("bus", "connects", "bus")] + mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool) + n_edge = branch_attr.size(0) + for feat_idx in (P_E, Q_E): + mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "branch": mask_branch, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + } + + return data + + class AddPFHeteroMask(BaseTransform): """Creates masks for a heterogeneous power flow graph.""" diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index eaaca66..d6f1b8c 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -8,6 +8,7 @@ from gridfm_graphkit.datasets.masking import ( AddOPFHeteroMask, AddPFHeteroMask, + AddRandomHeteroMask, SimulateMeasurements, ) from gridfm_graphkit.io.registries import TRANSFORM_REGISTRY @@ -20,7 +21,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddPFHeteroMask()) + + mask_type 
= getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose @@ -34,7 +41,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddOPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddOPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose From 78253bae6d0884426a6e65c0e2cd441b3f186e45 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:59:50 -0400 Subject: [PATCH 45/62] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f8c7b76..0199c6e 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -16,8 +16,8 @@ OUT_DIR=../scripts mkdir $OUT_DIR -python benchmark_model_inference.py --model hetero --config $CONF_PATH/case30_ieee_base.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true -python benchmark_model_inference.py --model hetero --config $CONF_PATH/case118_ieee_base.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true +python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true ###################################### @@ -117,11 +117,11 @@ if MODEL_TYPE == "grit": # Positional encoding config (only GRIT uses these) + # Read enablement and dimensions from data config (canonical source). 
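    # A minimal sketch of the dotted-path lookup pattern the lines below use;
    # `cfg_get` is a hypothetical helper, not part of gridfm_graphkit, shown
    # only to make the hasattr/getattr chains easier to read:
    #
    #     def cfg_get(obj, dotted, default=None):
    #         # Walk "a.b.c" one attribute at a time and fall back to
    #         # `default` as soon as any level is missing.
    #         for name in dotted.split("."):
    #             if not hasattr(obj, name):
    #                 return default
    #             obj = getattr(obj, name)
    #         return obj
    #
    #     # e.g. RRWP_ENABLED = cfg_get(config_args, "data.posenc_RRWP.enable", False)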
RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0 - RWSE_ENABLED = hasattr(config_args.model, "encoder") and getattr(config_args.model.encoder, "node_encoder", False) \ - and "RWSE" in getattr(config_args.model.encoder, "node_encoder_name", "") - RWSE_TIMES = getattr(config_args.model.encoder.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 + RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False) + RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0 else: RRWP_ENABLED = False RRWP_KSTEPS = 0 From 72a744930f3d91a3e11e4cd27e3ed3e9a7983cff Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:12:41 -0400 Subject: [PATCH 46/62] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- scripts/benchmark_model_inference.py | 2 +- scripts/run_benchmark.sh | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index f8c7b76..878cbbd 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -129,7 +129,7 @@ RWSE_TIMES = 0 # Keep original batch sizes list -batch_sizes = [1, 2, 4, 8, 16, 32, 64, 96, 128, 256, 512, 640, 768, 1024, 2048, 2560, 3072, 3584, 4096, 6144, 9216, 13824, 17280, 20736, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 75000, 80000, 85000, 90000] +batch_sizes = [1, 2, 4, 8, 16, 32] iterations = args.iterations # ---------------------------- diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 6483b81..744cfa2 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -3,13 +3,11 @@ set +e # Do NOT exit on error CONFIGS=( - "gridfm01" - "gridfm02" + "grit01" ) CONFIG_PATHS=( - "../examples/config/gridFMv0.1_pretraining.yaml" - "../examples/config/gridFMv0.2_pretraining.yaml" + "../examples/config/r2-1_grit_pretraining_RWSE_multi.yaml" ) GRAPH_SIZES=( @@ -37,4 +35,4 @@ for i in "${!CONFIGS[@]}"; do --num_nodes "$nodes" \ --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes" done -done \ No newline at end of file +done From 549a525199502f52c73e0a85234aae76b0a80688 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:32:33 -0400 Subject: [PATCH 47/62] adjust example parameters Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 4 ++-- scripts/benchmark_model_inference.py | 2 ++ scripts/run_benchmark.sh | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 7f89c56..70813ce 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -28,7 +28,7 @@ model: # edge_dim must match the bus-bus edge feature count after transforms # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) edge_dim: 10 - hidden_size: 116 + hidden_size: 496 # input_dim = bus feature count (used by GRIT core FeatureEncoder) input_dim: 15 # Hetero adapter head dimensions @@ -36,7 +36,7 @@ model: input_gen_dim: 6 output_bus_dim: 2 
output_gen_dim: 1 - num_layers: 10 + num_layers: 7 type: GRIT act: relu encoder: diff --git a/scripts/benchmark_model_inference.py b/scripts/benchmark_model_inference.py index e9af4c6..fe3010f 100644 --- a/scripts/benchmark_model_inference.py +++ b/scripts/benchmark_model_inference.py @@ -96,6 +96,8 @@ config_args = NestedNamespace(**base_config) model = load_model(config_args).to(device).eval() +tot_params = sum(p.numel() for p in model.parameters() if p.requires_grad) +print("**Total model trainable params: {}".format(tot_params)) # ---------------------------- # Parameters diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 744cfa2..2054f30 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -7,7 +7,7 @@ CONFIGS=( ) CONFIG_PATHS=( - "../examples/config/r2-1_grit_pretraining_RWSE_multi.yaml" + "../examples/config/GRIT_PF_datakit_case14.yaml" ) GRAPH_SIZES=( From c181d1206b9c1364a557b5d2fa548f48e02fb52b Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Fri, 27 Mar 2026 08:33:44 -0400 Subject: [PATCH 48/62] update GRIT wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index aab8b93..82fef9d 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -126,7 +126,7 @@ class GritTransformer(torch.nn.Module): 2023. """ - def __init__(self, args): + def __init__(self, args, include_decoder=True): super().__init__() @@ -195,7 +195,8 @@ def __init__(self, args): self.layers = nn.Sequential(*layers) - self.decoder = GraphHead(dim_inner, dim_out) + if include_decoder: + self.decoder = GraphHead(dim_inner, dim_out) def forward(self, batch): """ @@ -266,8 +267,9 @@ def __init__(self, args): args.model.gt.dim_hidden = args.model.hidden_size # The original homogeneous GRIT - # (encoder + optional PE encoders + transformer layers + GraphHead) - self.grit = GritTransformer(args) + # (encoder + optional PE encoders + transformer layers) + # Decoder is excluded — this adapter provides its own per-type heads. + self.grit = GritTransformer(args, include_decoder=False) # Per-node-type output heads (replace GraphHead for hetero output) self.bus_head = nn.Sequential( From d4405fd63d50069b6ecb3c3e90c697b582867c57 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Fri, 27 Mar 2026 09:25:33 -0400 Subject: [PATCH 49/62] update GRIT wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 82fef9d..99519eb 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -178,6 +178,11 @@ def __init__(self, args, include_decoder=True): layers = [] for ll in range(num_layers): + # The last layer's edge output is never consumed downstream + # (only node features feed into the output heads), so skip + # creating O_e / norm_e parameters to avoid DDP unused-parameter + # errors. 
+            is_last = (ll == num_layers - 1)
             layers.append(GritTransformerLayer(
                 in_dim=args.model.gt.dim_hidden,
                 out_dim=args.model.gt.dim_hidden,
                 num_heads=args.model.gt.n_heads,
                 dropout=args.model.gt.dropout,
                 act=args.model.gnn.act,
                 attn_dropout=args.model.gt.attn_dropout,
                 layer_norm=args.model.gt.layer_norm,
                 batch_norm=args.model.gt.batch_norm,
                 residual=True,
-                norm_e=args.model.gt.attn.norm_e,
-                O_e=args.model.gt.attn.O_e,
+                norm_e=False if is_last else args.model.gt.attn.norm_e,
+                O_e=False if is_last else args.model.gt.attn.O_e,
                 cfg=args.model.gt,
             ))

From 2452df3e6798a9e39ac52f229b7eea328ec45273 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Fri, 27 Mar 2026 16:54:21 -0400
Subject: [PATCH 50/62] add edge norm

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 examples/config/GRIT_PF_datakit_case14.yaml |  1 +
 gridfm_graphkit/models/grit_transformer.py  | 23 ++++++++++++++++++++-
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml
index 70813ce..8ce3a97 100644
--- a/examples/config/GRIT_PF_datakit_case14.yaml
+++ b/examples/config/GRIT_PF_datakit_case14.yaml
@@ -7,6 +7,7 @@
   baseMVA: 100
   mask_type: rnd # or deterministic
   mask_ratio: 0.5 # for random masking only
+  mask_value: -1
   normalization: HeteroDataMVANormalizer
   networks:
   - case14_ieee

diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py
index 99519eb..eaaf11f 100644
--- a/gridfm_graphkit/models/grit_transformer.py
+++ b/gridfm_graphkit/models/grit_transformer.py
@@ -29,6 +29,27 @@ def forward(self, batch):
         return batch


+class BatchNorm1dEdge(torch.nn.Module):
+    r"""A batch normalization layer for edge-level features.
+
+    Args:
+        dim_in (int): BatchNorm input dimension.
+        eps (float): BatchNorm eps.
+        momentum (float): BatchNorm momentum.
+ """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.edge_attr = self.bn(batch.edge_attr) + return batch + + class LinearNodeEncoder(torch.nn.Module): def __init__(self, dim_in, emb_dim): super().__init__() @@ -84,7 +105,7 @@ def __init__( # Encode integer edge features via nn.Embeddings self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) if args.encoder.edge_encoder_bn: - self.edge_encoder_bn = BatchNorm1dNode(enc_dim_edge, 1e-5, 0.1) + self.edge_encoder_bn = BatchNorm1dEdge(enc_dim_edge, 1e-5, 0.1) def forward(self, batch): for module in self.children(): From ef2ce5807dd90da87d3c82aa61bdae28509c7d2d Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:18:41 -0400 Subject: [PATCH 51/62] extend random masking to PD QD PG Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 6 +- gridfm_graphkit/datasets/globals.py | 2 + gridfm_graphkit/datasets/masking.py | 8 +-- gridfm_graphkit/datasets/normalizers.py | 2 + gridfm_graphkit/models/grit_transformer.py | 50 ++++++++++++- gridfm_graphkit/training/loss.py | 77 +++++++++++++++++++++ 6 files changed, 137 insertions(+), 8 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 8ce3a97..909c616 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -78,9 +78,9 @@ training: batch_size: 8 epochs: 500 loss_weights: - - 0.01 - - 0.09 - - 0.9 + - 0.99 + - 0.001 + - 0.009 losses: - PBE - MaskedGenMSE diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3..9627cd8 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -24,6 +24,8 @@ VA_OUT = 1 PG_OUT = 2 QG_OUT = 3 +PD_OUT = 4 # for random masking +QD_OUT = 5 # for random masking PG_OUT_GEN = 0 diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index 6b9d3e4..01924a2 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -37,11 +37,11 @@ class AddRandomHeteroMask(BaseTransform): """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. Each selected feature dimension is independently masked per node/edge with - probability ``mask_ratio``. Masked bus features: VM, VA, QG. Masked gen - features: PG. Masked branch features: P_E, Q_E. + probability ``mask_ratio``. Masked bus features: PD, QD, VM, VA, QG. + Masked gen features: PG. Masked branch features: P_E, Q_E. The output ``data.mask_dict`` has the same structure as the deterministic - PF / OPF masks so that downstream losses (``MaskedBusMSE``, ``MaskedGenMSE``, + PF / OPF masks so that downstream losses (``MaskedReconstructionMSE``, ``PBELoss``, etc.) work without modification. 
""" @@ -61,7 +61,7 @@ def forward(self, data): # Random bus mask on variable features the model reconstructs mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) n_bus = bus_x.size(0) - for feat_idx in (VM_H, VA_H, QG_H): + for feat_idx in (PD_H, QD_H, VM_H, VA_H, QG_H): mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio # Random gen mask on PG diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a6..6189404 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -306,6 +306,8 @@ def inverse_output(self, output, batch): gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA + bus_output[:, PD_OUT] *= self.baseMVA # for random masking + bus_output[:, QD_OUT] *= self.baseMVA # for random masking gen_output[:, PG_OUT_GEN] *= self.baseMVA def get_stats(self) -> dict: diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index eaaf11f..5e20f18 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -6,6 +6,8 @@ from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder from gridfm_graphkit.models.grit_layer import GritTransformerLayer from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder +from torch_scatter import scatter_add +from gridfm_graphkit.datasets.globals import PG_H class BatchNorm1dNode(torch.nn.Module): @@ -242,6 +244,48 @@ def forward(self, batch): return batch +def aggregate_pg(batch, mask_value=-1.0): + """Aggregate per-generator active power (PG) onto bus nodes. + + In the homogeneous reference, PG is a direct bus feature visible to the + transformer alongside Pd, Qd, Vm, Va, etc. In the heterogeneous + representation PG lives on separate generator nodes, so it must be + aggregated onto buses before the transformer can learn voltage-power + coupling. + + Masked generators (where PG has been replaced by the mask value) are + excluded from the sum to avoid corrupting the aggregated signal. Buses + where *all* connected generators are masked receive the mask value + instead, preserving a consistent "unknown" indicator. + """ + gen_to_bus = batch["gen", "connected_to", "bus"].edge_index + gen_pg = batch["gen"].x[:, PG_H] + gen_masked = batch.mask_dict["gen"][:, PG_H] # True = masked + + # Zero out masked generators so they don't contribute to the sum + pg_clean = torch.where(gen_masked, torch.zeros_like(gen_pg), gen_pg) + + pg_per_bus = scatter_add( + pg_clean, + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + + # Check which buses have ALL generators masked (or no generators at all) + unmasked_count = scatter_add( + (~gen_masked).float(), + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + all_masked = unmasked_count == 0 + + # Set mask_value for fully-masked buses + pg_per_bus[all_masked] = mask_value + + return pg_per_bus + @MODELS_REGISTRY.register("GRIT") class GritHeteroAdapter(torch.nn.Module): @@ -321,8 +365,12 @@ def forward(self, batch): predicted output features. 
""" # --- Extract bus-only homogeneous subgraph --- + # Aggregate generator PG onto buses + pg_per_bus = aggregate_pg(batch, mask_value=self.grit.mask_value[0].item()) + bus_x = torch.cat([batch["bus"].x, pg_per_bus.unsqueeze(-1)], dim=-1) # 15 → 16D + homo = Data( - x=batch["bus"].x, + x=bus_x, y=batch["bus"].y, edge_index=batch["bus", "connects", "bus"].edge_index, edge_attr=batch["bus", "connects", "bus"].edge_attr, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 02df3bc..f28370d 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -18,6 +18,8 @@ VA_OUT, QG_OUT, PG_OUT, + PD_OUT, + QD_OUT, # Generator feature indices PG_H, # Edge feature indices @@ -142,6 +144,81 @@ def forward( return {"loss": loss, "Masked bus MSE loss": loss.detach()} +@LOSS_REGISTRY.register("MaskedReconstructionMSE") +class MaskedReconstructionMSE(BaseLoss): + """Unified masked MSE over bus-level quantities [VM, VA, PG, QG, PD, QD]. + + Mirrors the homogeneous reference MaskedMSE by combining bus predictions + and aggregated generator PG into a single prediction/target/mask tensor. + PG targets are aggregated from generator ground truth onto buses via + scatter_add; the bus-level PG mask is True when any generator at the bus + is masked, indicating that the model must reconstruct that quantity. + + Replaces the separate MaskedBusMSE + MaskedGenMSE pair. + Requires output_bus_dim >= 6 so the bus head predicts + [VM, VA, PG, QG, PD, QD]. + """ + + def __init__(self, loss_args, args): + super().__init__() + self.reduction = "mean" + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + ): + pred_bus = pred_dict["bus"] + target_bus = target_dict["bus"] + num_bus = target_bus.size(0) + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + + # --- Build target: [VM, VA, PG_agg, QG, PD, QD] --- + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + target = torch.stack([ + target_bus[:, VM_H], + target_bus[:, VA_H], + target_pg_agg, + target_bus[:, QG_H], + target_bus[:, PD_H], + target_bus[:, QD_H], + ], dim=1) + + # --- Build mask: [N_bus, 6] --- + # PG bus-level mask: True if any generator at the bus has PG masked + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) > 0 + + mask = torch.stack([ + mask_dict["bus"][:, VM_H], + mask_dict["bus"][:, VA_H], + any_gen_masked, + mask_dict["bus"][:, QG_H], + mask_dict["bus"][:, PD_H], + mask_dict["bus"][:, QD_H], + ], dim=1) + + # --- Prediction: [VM, VA, PG, QG, PD, QD] from bus head --- + pred = pred_bus[:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT, PD_OUT, QD_OUT]] + + loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) + return {"loss": loss, "Masked reconstruction MSE loss": loss.detach()} + + @LOSS_REGISTRY.register("MSE") class MSELoss(BaseLoss): """Standard Mean Squared Error loss.""" From 55182a65f1dc6040e628945fa8418613711944eb Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:25:16 -0400 Subject: [PATCH 52/62] extend random masking to PD QD PG Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git 
a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 909c616..88857f1 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -30,12 +30,12 @@ model: # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) edge_dim: 10 hidden_size: 496 - # input_dim = bus feature count (used by GRIT core FeatureEncoder) - input_dim: 15 + # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) + input_dim: 16 # Hetero adapter head dimensions - input_bus_dim: 15 + input_bus_dim: 16 input_gen_dim: 6 - output_bus_dim: 2 + output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] output_gen_dim: 1 num_layers: 7 type: GRIT @@ -79,16 +79,13 @@ training: epochs: 500 loss_weights: - 0.99 - - 0.001 - - 0.009 + - 0.01 losses: - PBE - - MaskedGenMSE - - MaskedBusMSE + - MaskedReconstructionMSE loss_args: - {} - {} - - {} accelerator: auto devices: auto strategy: auto From 85b4ddfea754f18f662a041c5153e18e066a2033 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:40:18 -0400 Subject: [PATCH 53/62] update PBLoss to support transformer wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index f28370d..fe4106b 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -469,14 +469,23 @@ def forward( S_injection = torch.diag(V) @ Y_bus_conj @ V_conj # --- Net power from predictions/targets --- - # Pg: aggregate generator predictions onto buses + # Pg: use bus head prediction where masked, ground truth where known. + # Ground truth is aggregated from generator targets onto buses. 
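#         (A toy check of the scatter_add aggregation used in this hunk, with
#         made-up numbers; illustrative sketch, not part of the patch:
#
#             import torch
#             from torch_scatter import scatter_add
#
#             pg = torch.tensor([1.0, 2.0, 4.0])    # three generators
#             gen_to_bus = torch.tensor([0, 0, 2])  # gens 0, 1 on bus 0; gen 2 on bus 2
#             scatter_add(pg, gen_to_bus, dim=0, dim_size=4)
#             # -> tensor([3., 0., 4., 0.]); buses without generators stay zero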
gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] - Pg_per_bus = scatter_add( - pred_dict["gen"].squeeze(-1), + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], gen_to_bus_ei[1], dim=0, dim_size=num_bus, ) + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) > 0 + Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) Pd = target_bus[:, PD_H] Qd = target_bus[:, QD_H] From 9e2a3133661f21e1efb1bfdfd8bd67a3542f7816 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:55:23 -0400 Subject: [PATCH 54/62] update PBLoss to support transformer wrapper Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- examples/config/GRIT_PF_datakit_case14.yaml | 2 +- gridfm_graphkit/models/grit_transformer.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml index 88857f1..cef4b50 100644 --- a/examples/config/GRIT_PF_datakit_case14.yaml +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -36,7 +36,7 @@ model: input_bus_dim: 16 input_gen_dim: 6 output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] - output_gen_dim: 1 + output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed num_layers: 7 type: GRIT act: relu diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 5e20f18..3422a21 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -347,11 +347,17 @@ def __init__(self, args): nn.LeakyReLU(), nn.Linear(dim_inner, output_bus_dim), ) - self.gen_head = nn.Sequential( - nn.Linear(input_gen_dim, dim_inner), - nn.LeakyReLU(), - nn.Linear(dim_inner, output_gen_dim), - ) + # gen_head is only needed for tasks that require per-generator + # predictions (e.g. OPF cost computation). When output_gen_dim is 0 + # or not set, skip it to avoid DDP unused-parameter errors. + if output_gen_dim and output_gen_dim > 0: + self.gen_head = nn.Sequential( + nn.Linear(input_gen_dim, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_gen_dim), + ) + else: + self.gen_head = None def forward(self, batch): """Forward pass on a heterogeneous power-grid batch. 
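# The conditional-head pattern above, reduced to a standalone sketch
# (hypothetical OptionalHead module, assumed code rather than part of this repo):
#
#     import torch
#     from torch import nn
#
#     class OptionalHead(nn.Module):
#         def __init__(self, dim_in, dim_out):
#             super().__init__()
#             # No parameters are registered when dim_out == 0, so DDP
#             # never sees weights that receive no gradient.
#             self.head = nn.Linear(dim_in, dim_out) if dim_out > 0 else None
#
#         def forward(self, x):
#             # Fall back to the identity so callers always get a tensor.
#             return self.head(x) if self.head is not None else x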
@@ -392,6 +398,6 @@ def forward(self, batch):

         # --- Per-type decoding ---
         bus_out = self.bus_head(homo.x)
-        gen_out = self.gen_head(batch["gen"].x)
+        gen_out = self.gen_head(batch["gen"].x) if self.gen_head is not None else batch["gen"].x

         return {"bus": bus_out, "gen": gen_out}

From 6f112d89e4a99945ddd4a678afc9c63a86558b81 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 30 Mar 2026 13:14:28 -0400
Subject: [PATCH 55/62] patch for admittance matrix indices in PBLoss

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 gridfm_graphkit/training/loss.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py
index fe4106b..819b2e6 100644
--- a/gridfm_graphkit/training/loss.py
+++ b/gridfm_graphkit/training/loss.py
@@ -455,8 +455,13 @@ def forward(
         V_conj = torch.conj(V)

         # --- Admittance matrix from bus-bus edge attrs ---
-        # Use Yff (diagonal-block) real/imag as the admittance entries
-        edge_complex = bus_edge_attr[:, YFF_TT_R] + 1j * bus_edge_attr[:, YFF_TT_I]
+        # Off-diagonal entries of Y-bus: Y[from][to] = Yft, Y[to][from] = Ytf.
+        # The dataset stores forward edges with Yft at YFT_TF columns and
+        # reverse edges with Ytf at the same columns, so indexing YFT_TF_R/I
+        # gives the correct off-diagonal admittance for both directions.
+        # (YFF_TT columns hold diagonal-block entries Yff/Ytt which belong on
+        # the Y-bus diagonal, not at off-diagonal edge positions.)
+        edge_complex = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I]

         Y_bus_sparse = to_torch_coo_tensor(
             bus_edge_index,

From b5f218ee353270ace3b80d90ba2632deb872310b Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Mon, 30 Mar 2026 13:46:53 -0400
Subject: [PATCH 56/62] patch for admittance matrix indices in PBLoss

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 gridfm_graphkit/training/loss.py | 50 ++++++++++++++++++++++++++------
 1 file changed, 41 insertions(+), 9 deletions(-)

diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py
index 819b2e6..bb8cf8d 100644
--- a/gridfm_graphkit/training/loss.py
+++ b/gridfm_graphkit/training/loss.py
@@ -455,17 +455,49 @@ def forward(
         V_conj = torch.conj(V)

         # --- Admittance matrix from bus-bus edge attrs ---
-        # Off-diagonal entries of Y-bus: Y[from][to] = Yft, Y[to][from] = Ytf.
-        # The dataset stores forward edges with Yft at YFT_TF columns and
-        # reverse edges with Ytf at the same columns, so indexing YFT_TF_R/I
-        # gives the correct off-diagonal admittance for both directions.
-        # (YFF_TT columns hold diagonal-block entries Yff/Ytt which belong on
-        # the Y-bus diagonal, not at off-diagonal edge positions.)
-        edge_complex = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I]
+        # The Y-bus matrix has off-diagonal AND diagonal entries.
+        #
+        # Off-diagonal: Y[from][to] = Yft, Y[to][from] = Ytf, stored in the
+        # YFT_TF columns of the edge attributes.
+        #
+        # Diagonal: Y[k][k] = sum of Yff/Ytt for all branches at bus k.
+        # The dataset stores Yff (forward edges) and Ytt (reverse edges) in
+        # the YFF_TT columns. For every edge, YFF_TT at the *source* bus
+        # gives that branch's diagonal contribution at the source. Summing
+        # over all edges with source == k yields the full branch-diagonal.
+ # + # The reference project loads a pre-built Y-bus (y_bus_data.parquet) + # that includes self-loops for diagonal entries. Here we reconstruct + # the same structure from per-branch pi-model parameters. + + # Off-diagonal admittance values + edge_offdiag = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I] + + # Diagonal: aggregate Yff/Ytt to source bus of each edge + Y_diag_r = scatter_add( + bus_edge_attr[:, YFF_TT_R], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag_i = scatter_add( + bus_edge_attr[:, YFF_TT_I], + bus_edge_index[0], + dim=0, + dim_size=num_bus, + ) + Y_diag = Y_diag_r + 1j * Y_diag_i + + # Build complete Y-bus: off-diagonal edges + self-loops for diagonal + diag_idx = torch.arange(num_bus, device=bus_edge_index.device) + full_edge_index = torch.cat( + [bus_edge_index, torch.stack([diag_idx, diag_idx])], dim=1, + ) + full_edge_values = torch.cat([edge_offdiag, Y_diag]) Y_bus_sparse = to_torch_coo_tensor( - bus_edge_index, - edge_complex, + full_edge_index, + full_edge_values, size=(num_bus, num_bus), ) Y_bus_conj = torch.conj(Y_bus_sparse) From 823438f13ef8cc87fe601f45627883fbb814f7cf Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Mon, 30 Mar 2026 14:21:47 -0400 Subject: [PATCH 57/62] PBLoss support for random masking Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index bb8cf8d..1966a39 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -524,8 +524,18 @@ def forward( ) > 0 Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) - Pd = target_bus[:, PD_H] - Qd = target_bus[:, QD_H] + # Pd, Qd: use prediction where masked, target where known. + # For deterministic PF/OPF masks PD/QD are never masked, so this + # is equivalent to always using target. For random masking this + # lets PBE provide gradient signal for PD/QD reconstruction. + if pred_bus.size(1) > PD_OUT: + Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H]) + else: + Pd = target_bus[:, PD_H] + if pred_bus.size(1) > QD_OUT: + Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H]) + else: + Qd = target_bus[:, QD_H] # Qg: use prediction if the model predicts it, else use target if pred_bus.size(1) > QG_OUT: From 821a40ca2e83051fe8eff49ab79d5f49887b1aec Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:02:40 -0400 Subject: [PATCH 58/62] slice mse features Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/tasks/opf_task.py | 11 ++++++++--- gridfm_graphkit/tasks/pf_task.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index b28c5a0..e2f0bcd 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -233,18 +233,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Active Power Loss"] = final_residual_real_bus.detach() loss_dict["Reactive Power Loss"] = final_residual_imag_bus.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. 
GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. + output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index cdc9d64..c17696e 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -166,18 +166,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["PBE Mean"] = pbe_mean.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. + output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) From 9dabcbf8f11cd75c4e1e338e3806d1c4f0ba11da Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:41:37 -0400 Subject: [PATCH 59/62] adjust denorm for GRIT Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/datasets/normalizers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 6189404..01936a2 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -20,6 +20,8 @@ # Output feature indices PG_OUT, QG_OUT, + PD_OUT, + QD_OUT, PG_OUT_GEN, # Generator feature indices PG_H, @@ -306,8 +308,10 @@ def inverse_output(self, output, batch): gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA - bus_output[:, PD_OUT] *= self.baseMVA # for random masking - bus_output[:, QD_OUT] *= self.baseMVA # for random masking + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= self.baseMVA + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= self.baseMVA gen_output[:, PG_OUT_GEN] *= self.baseMVA def get_stats(self) -> dict: @@ -608,6 +612,10 @@ def inverse_output(self, output, batch): # Scale per-unit power back to MW/Mvar bus_output[:, PG_OUT] *= b_bus bus_output[:, QG_OUT] *= b_bus + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= b_bus + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= b_bus gen_output[:, PG_OUT_GEN] *= b_gen def get_stats(self) -> dict: From f9611037cdba89f61ec332b9859c4e165fd2de6e Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:46:56 -0400 Subject: [PATCH 60/62] update mse Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 1966a39..960e3ca 100644 --- 
a/gridfm_graphkit/training/loss.py
+++ b/gridfm_graphkit/training/loss.py
@@ -100,9 +100,12 @@ def forward(
         mask_dict,
         model=None,
     ):
+        gen_pred = pred_dict["gen"][:, : (PG_H + 1)]
+        gen_target = target_dict["gen"][:, : (PG_H + 1)]
+        mask = mask_dict["gen"][:, : (PG_H + 1)]
         loss = F.mse_loss(
-            pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]],
-            target_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]],
+            gen_pred[mask],
+            gen_target[mask],
             reduction=self.reduction,
         )
         return {"loss": loss, "Masked generator MSE loss": loss.detach()}

From 5738c9533f54f78f8cf086dbf82518f956e4ad69 Mon Sep 17 00:00:00 2001
From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
Date: Wed, 1 Apr 2026 09:51:47 -0400
Subject: [PATCH 61/62] clamp known (unmasked) values in evaluate

Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com>
---
 gridfm_graphkit/tasks/pf_task.py | 80 +++++++++++++++++++++++++++-----
 1 file changed, 68 insertions(+), 12 deletions(-)

diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py
index c17696e..74d6bee 100644
--- a/gridfm_graphkit/tasks/pf_task.py
+++ b/gridfm_graphkit/tasks/pf_task.py
@@ -5,6 +5,8 @@
     QG_H,
     VM_H,
     VA_H,
+    # Generator feature indices
+    PG_H,
     # Output feature indices
     VM_OUT,
     VA_OUT,
@@ -80,12 +82,35 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):

         # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH
         # output["bus"] = target

-        Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr)
+        # Clamp known (unmasked) values to ground truth, matching the
+        # PBELoss convention used during training. The model is only
+        # responsible for predicting masked unknowns; using raw predictions
+        # for known quantities would inflate the residual.
+        mask_bus = batch.mask_dict["bus"]
+        eval_bus = output["bus"].clone()
+
eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + ) + + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) residual_P = torch.abs(residual_P) @@ -396,14 +460,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): mask_PV = batch.mask_dict["PV"] mask_REF = batch.mask_dict["REF"] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), From b661b85acc756badb29b3b15dd7093eedac11b56 Mon Sep 17 00:00:00 2001 From: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:42:17 -0400 Subject: [PATCH 62/62] clean up Signed-off-by: Thomas Tolhurst <99353435+ttolhurst@users.noreply.github.com> --- gridfm_graphkit/models/grit_transformer.py | 3 - gridfm_graphkit/tasks/pf_task.py | 145 +++++++++------------ gridfm_graphkit/training/loss.py | 19 ++- 3 files changed, 76 insertions(+), 91 deletions(-) diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py index 3422a21..51767e2 100644 --- a/gridfm_graphkit/models/grit_transformer.py +++ b/gridfm_graphkit/models/grit_transformer.py @@ -236,9 +236,6 @@ def forward(self, batch): Returns: output (Tensor): Output node features of shape [num_nodes, output_dim]. """ - # print('xxxx',batch.x.min(), batch.x.max()) - # print('yyyyy',batch.y.min(), batch.y.max()) - # print('>>>>', batch) for module in self.children(): batch = module(batch) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index 74d6bee..46be699 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -36,11 +36,67 @@ import pandas as pd +def _build_bus_target(batch, num_bus): + """Build a 4-column bus-level target tensor [VM, VA, PG_agg, QG]. + + Generator PG is aggregated onto buses via scatter_add so that the + target layout matches the bus head output columns. + """ + _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] + agg_gen_on_bus = scatter_add( + batch.y_dict["gen"], + gen_to_bus_index, + dim=0, + dim_size=num_bus, + ) + target = torch.stack( + [ + batch.y_dict["bus"][:, VM_H], + batch.y_dict["bus"][:, VA_H], + agg_gen_on_bus.squeeze(), + batch.y_dict["bus"][:, QG_H], + ], + dim=1, + ) + return target, gen_to_bus_index, agg_gen_on_bus + + +def _clamp_known_to_ground_truth(output_bus, target, batch, gen_to_bus_index, num_bus): + """Replace predicted values with ground truth for known (unmasked) quantities. + + During both training (PBELoss) and evaluation, the model is only + responsible for predicting masked unknowns. Known quantities (e.g. 
+ VM at PV buses, VA at REF, PG at non-slack generators) are clamped to + ground truth so that prediction errors on non-target outputs do not + pollute the power-balance residual. + """ + mask_bus = batch.mask_dict["bus"] + eval_bus = output_bus.clone() + eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output_bus[:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output_bus[:, VA_OUT], target[:, VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output_bus[:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output_bus[:, QG_OUT], target[:, QG_OUT], + ) + return eval_bus + + @TASK_REGISTRY.register("PowerFlow") class PowerFlowTask(ReconstructionTask): """ - Concrete Optimal Power Flow task. - Extends ReconstructionTask and adds OPF-specific metrics. + Concrete Power Flow task. + Extends ReconstructionTask and adds PF-specific evaluation metrics + (power balance residuals, per-bus-type RMSE). """ def __init__(self, args, data_normalizers): @@ -60,49 +116,10 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - # output_agg = torch.cat([batch.y_dict["bus"], agg_gen_on_bus], dim=1) - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, - ) - # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH - # output["bus"] = target - - # Clamp known (unmasked) values to ground truth, matching the - # PBELoss convention used during training. The model is only - # responsible for predicting masked unknowns; using raw predictions - # for known quantities would inflate the residual. 
- mask_bus = batch.mask_dict["bus"] - eval_bus = output["bus"].clone() - eval_bus[:, VM_OUT] = torch.where( - mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], - ) - eval_bus[:, VA_OUT] = torch.where( - mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], - ) - gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() - any_gen_masked = ( - scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 - ) - eval_bus[:, PG_OUT] = torch.where( - any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], - ) - eval_bus[:, QG_OUT] = torch.where( - mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) @@ -393,44 +410,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - - # Build target for clamping known (unmasked) values - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, - ) - # Clamp known values to ground truth (same as PBELoss during training) - mask_bus = batch.mask_dict["bus"] - eval_bus = output["bus"].clone() - eval_bus[:, VM_OUT] = torch.where( - mask_bus[:, VM_H], output["bus"][:, VM_OUT], target[:, VM_OUT], - ) - eval_bus[:, VA_OUT] = torch.where( - mask_bus[:, VA_H], output["bus"][:, VA_OUT], target[:, VA_OUT], - ) - gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() - any_gen_masked = ( - scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 - ) - eval_bus[:, PG_OUT] = torch.where( - any_gen_masked, output["bus"][:, PG_OUT], target[:, PG_OUT], - ) - eval_bus[:, QG_OUT] = torch.where( - mask_bus[:, QG_H], output["bus"][:, QG_OUT], target[:, QG_OUT], + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, ) Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 960e3ca..d700b89 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -441,7 +441,14 @@ def forward( bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")] mask_bus = mask_dict["bus"] - # --- Voltage: use prediction where masked, target where known --- + # --- Clamp known values to ground truth --- + # In power flow, certain variables are "known" (unmasked) at each + # bus type (e.g. VM at PV buses, VA at REF). The model only needs + # to predict *masked* unknowns; for everything else we substitute + # the ground truth so that errors in non-target outputs do not + # pollute the physics loss. This matches the reference's + # ``temp_pred[unmasked] = target[unmasked]`` convention. 
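#         (The same clamp on toy tensors, illustrative values only:
#
#             import torch
#             pred = torch.tensor([1.02, 0.97, 1.05])
#             target = torch.tensor([1.00, 1.00, 1.00])
#             mask = torch.tensor([True, False, True])  # True = model must predict
#             torch.where(mask, pred, target)
#             # -> tensor([1.0200, 1.0000, 1.0500]); known entries pinned to target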
+ Vm_pred = pred_bus[:, VM_OUT] Va_pred = pred_bus[:, VA_OUT] Vm_target = target_bus[:, VM_H] @@ -527,10 +534,10 @@ def forward( ) > 0 Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg) - # Pd, Qd: use prediction where masked, target where known. - # For deterministic PF/OPF masks PD/QD are never masked, so this - # is equivalent to always using target. For random masking this - # lets PBE provide gradient signal for PD/QD reconstruction. + # Pd, Qd, Qg: same clamp-to-ground-truth logic. The size guard + # (``pred_bus.size(1) > *_OUT``) handles models with a narrow bus + # head (e.g. output_bus_dim=4) that don't predict PD/QD/QG; in that + # case the target is always used. if pred_bus.size(1) > PD_OUT: Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H]) else: @@ -539,8 +546,6 @@ def forward( Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H]) else: Qd = target_bus[:, QD_H] - - # Qg: use prediction if the model predicts it, else use target if pred_bus.size(1) > QG_OUT: Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H]) else:
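The Y-bus reconstruction introduced in PATCH 56 and the injection product used by PBELoss can be sanity-checked end to end on a toy 3-bus chain. The sketch below is illustrative only: it assumes unit series admittances with no shunts or transformers, so simple constants stand in for the dataset's YFT_TF and YFF_TT columns, and plain index_add_ replaces torch_scatter's scatter_add.

import torch

# 3-bus chain (branches 0-1 and 1-2). As in the dataset, every branch is
# stored as a forward and a reverse directed edge.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
y_offdiag = torch.full((4,), -1.0 + 0.0j)       # stand-in for Yft / Ytf (YFT_TF)
y_diag_contrib = torch.full((4,), 1.0 + 0.0j)   # stand-in for Yff / Ytt (YFF_TT)

num_bus = 3
# Diagonal: sum each edge's Yff/Ytt contribution at its source bus.
Y_diag = torch.zeros(num_bus, dtype=torch.cfloat)
Y_diag.index_add_(0, edge_index[0], y_diag_contrib)

# Off-diagonal edges plus self-loops carrying the diagonal entries.
diag_idx = torch.arange(num_bus)
full_index = torch.cat([edge_index, torch.stack([diag_idx, diag_idx])], dim=1)
full_values = torch.cat([y_offdiag, Y_diag])
Y_bus = torch.sparse_coo_tensor(full_index, full_values,
                                (num_bus, num_bus)).to_dense()
# Y_bus comes out as the chain's graph Laplacian:
# [[ 1, -1,  0],
#  [-1,  2, -1],
#  [ 0, -1,  1]]

# Injection S_k = V_k * conj(sum_j Y_kj * V_j); with a flat voltage profile
# every injection is zero, i.e. power balance holds exactly.
V = torch.ones(num_bus, dtype=torch.cfloat)
S = V * (Y_bus.conj() @ V.conj())
print(S)  # tensor([0.+0.j, 0.+0.j, 0.+0.j])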