diff --git a/README.md b/README.md index ad85f05..aa57d94 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ source venv/bin/activate Install gridfm-graphkit in editable mode ```bash pip install -e . +pip install torch_sparse torch_scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html ``` Get PyTorch + CUDA version for torch-scatter diff --git a/examples/config/GRIT_PF_datakit_case14.yaml b/examples/config/GRIT_PF_datakit_case14.yaml new file mode 100644 index 0000000..cef4b50 --- /dev/null +++ b/examples/config/GRIT_PF_datakit_case14.yaml @@ -0,0 +1,92 @@ +callbacks: + patience: 100 + tol: 0 +task: + task_name: PowerFlow +data: + baseMVA: 100 + mask_type: rnd # or determinstic + mask_ratio: 0.5 # for random masking only + mask_value: -1 + normalization: HeteroDataMVANormalizer + networks: + - case14_ieee + scenarios: + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + posenc_RRWP: + enable: false + ksteps: 21 + posenc_RWSE: + enable: true + kernel: + times: 21 +model: + attention_head: 8 + dropout: 0.1 + # edge_dim must match the bus-bus edge feature count after transforms + # (P_E, Q_E, YFF_TT_R, YFF_TT_I, YFT_TF_R, YFT_TF_I, TAP, ANG_MIN, ANG_MAX, RATE_A) + edge_dim: 10 + hidden_size: 496 + # input_dim = bus feature count + aggregated PG (used by GRIT core FeatureEncoder) + input_dim: 16 + # Hetero adapter head dimensions + input_bus_dim: 16 + input_gen_dim: 6 + output_bus_dim: 6 # [VM, VA, PG, QG, PD, QD] + output_gen_dim: 0 # PG predicted at bus level; no per-generator head needed + num_layers: 7 + type: GRIT + act: relu + encoder: + node_encoder: true + edge_encoder: true + node_encoder_name: RWSE + node_encoder_bn: true + edge_encoder_bn: true + posenc_RWSE: + # kernel.times is synced automatically from data.posenc_RWSE.kernel.times + pe_dim: 20 + raw_norm_type: batchnorm + gt: + layer_type: GritTransformer + # dim_hidden is synced automatically from model.hidden_size + layer_norm: false + batch_norm: true + update_e: true + attn_dropout: 
0.2 + attn: + clamp: 5. + act: relu + full_attn: true + edge_enhance: true + O_e: true + norm_e: true + signed_sqrt: true + bn_momentum: 0.1 + bn_no_runner: false +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.99 + - 0.01 + losses: + - PBE + - MaskedReconstructionMSE + loss_args: + - {} + - {} + accelerator: auto + devices: auto + strategy: auto +verbose: true diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 23fdf08..ce36662 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -79,6 +79,11 @@ def main_cli(args): callbacks=get_training_callbacks(config_args), profiler=profiler, ) + + # print('******model*****') + # print(model) + # print('******model*****') + if args.command == "train" or args.command == "finetune": trainer.fit(model=model, datamodule=litGrid) diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3..9627cd8 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -24,6 +24,8 @@ VA_OUT = 1 PG_OUT = 2 QG_OUT = 3 +PD_OUT = 4 # for random masking +QD_OUT = 5 # for random masking PG_OUT_GEN = 0 diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 4ac0125..e6a4cfd 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -14,6 +14,11 @@ split_dataset_by_load_scenario_idx, ) from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk + +from gridfm_graphkit.datasets.posenc_stats import ComputePosencStat + +import torch_geometric.transforms as T + import numpy as np import random import warnings @@ -149,6 +154,24 @@ def setup(self, stage: str): data_normalizer=data_normalizer, transform=get_task_transforms(args=self.args), ) + + if ('posenc_RRWP' 
in self.args.data) and self.args.data.posenc_RRWP.enable: + pe_transform = ComputePosencStat(pe_types=['RRWP'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) + if ('posenc_RWSE' in self.args.data) and self.args.data.posenc_RWSE.enable: + pe_transform = ComputePosencStat(pe_types=['RWSE'], + cfg=self.args.data + ) + if dataset.transform is None: + dataset.transform = pe_transform + else: + dataset.transform = T.Compose([pe_transform, dataset.transform]) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] diff --git a/gridfm_graphkit/datasets/masking.py b/gridfm_graphkit/datasets/masking.py index b615e0c..01924a2 100644 --- a/gridfm_graphkit/datasets/masking.py +++ b/gridfm_graphkit/datasets/masking.py @@ -33,6 +33,60 @@ from torch_geometric.nn import MessagePassing +class AddRandomHeteroMask(BaseTransform): + """Creates random masks for self-supervised pretraining on heterogeneous power grid graphs. + + Each selected feature dimension is independently masked per node/edge with + probability ``mask_ratio``. Masked bus features: PD, QD, VM, VA, QG. + Masked gen features: PG. Masked branch features: P_E, Q_E. + + The output ``data.mask_dict`` has the same structure as the deterministic + PF / OPF masks so that downstream losses (``MaskedReconstructionMSE``, + ``PBELoss``, etc.) work without modification. 
+ """ + + def __init__(self, mask_ratio=0.5): + super().__init__() + self.mask_ratio = mask_ratio + + def forward(self, data): + bus_x = data.x_dict["bus"] + gen_x = data.x_dict["gen"] + + # Bus type indicators (needed by losses and test metrics) + mask_PQ = bus_x[:, PQ_H] == 1 + mask_PV = bus_x[:, PV_H] == 1 + mask_REF = bus_x[:, REF_H] == 1 + + # Random bus mask on variable features the model reconstructs + mask_bus = torch.zeros_like(bus_x, dtype=torch.bool) + n_bus = bus_x.size(0) + for feat_idx in (PD_H, QD_H, VM_H, VA_H, QG_H): + mask_bus[:, feat_idx] = torch.rand(n_bus) < self.mask_ratio + + # Random gen mask on PG + mask_gen = torch.zeros_like(gen_x, dtype=torch.bool) + mask_gen[:, PG_H] = torch.rand(gen_x.size(0)) < self.mask_ratio + + # Random branch mask on flow features + branch_attr = data.edge_attr_dict[("bus", "connects", "bus")] + mask_branch = torch.zeros_like(branch_attr, dtype=torch.bool) + n_edge = branch_attr.size(0) + for feat_idx in (P_E, Q_E): + mask_branch[:, feat_idx] = torch.rand(n_edge) < self.mask_ratio + + data.mask_dict = { + "bus": mask_bus, + "gen": mask_gen, + "branch": mask_branch, + "PQ": mask_PQ, + "PV": mask_PV, + "REF": mask_REF, + } + + return data + + class AddPFHeteroMask(BaseTransform): """Creates masks for a heterogeneous power flow graph.""" diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a6..01936a2 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -20,6 +20,8 @@ # Output feature indices PG_OUT, QG_OUT, + PD_OUT, + QD_OUT, PG_OUT_GEN, # Generator feature indices PG_H, @@ -306,6 +308,10 @@ def inverse_output(self, output, batch): gen_output = output["gen"] bus_output[:, PG_OUT] *= self.baseMVA bus_output[:, QG_OUT] *= self.baseMVA + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= self.baseMVA + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= self.baseMVA gen_output[:, PG_OUT_GEN] *= 
self.baseMVA def get_stats(self) -> dict: @@ -606,6 +612,10 @@ def inverse_output(self, output, batch): # Scale per-unit power back to MW/Mvar bus_output[:, PG_OUT] *= b_bus bus_output[:, QG_OUT] *= b_bus + if bus_output.size(1) > PD_OUT: + bus_output[:, PD_OUT] *= b_bus + if bus_output.size(1) > QD_OUT: + bus_output[:, QD_OUT] *= b_bus gen_output[:, PG_OUT_GEN] *= b_gen def get_stats(self) -> dict: diff --git a/gridfm_graphkit/datasets/posenc_stats.py b/gridfm_graphkit/datasets/posenc_stats.py new file mode 100644 index 0000000..5263b48 --- /dev/null +++ b/gridfm_graphkit/datasets/posenc_stats.py @@ -0,0 +1,166 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix, + to_undirected, to_dense_adj) +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter_add +from functools import partial +from gridfm_graphkit.datasets.rrwp import add_full_rrwp + +from torch_geometric.transforms import BaseTransform +from torch_geometric.data import Data, HeteroData +from typing import Any + +from torch_geometric.utils.num_nodes import maybe_num_nodes + +def compute_posenc_stats(data, pe_types, cfg): + """Precompute positional encodings for the given graph. + Supported PE statistics to precompute in original implementation, + selected by `pe_types`: + 'LapPE': Laplacian eigen-decomposition. + 'RWSE': Random walk landing probabilities (diagonals of RW matrices). + 'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED) + 'HKdiagSE': Diagonals of heat kernel diffusion. + 'ElstaticSE': Kernel based on the electrostatic interaction between nodes. + 'RRWP': Relative Random Walk Probabilities PE (Ours, for GRIT) + Args: + data: PyG graph + pe_types: Positional encoding types to precompute statistics for. + This can also be a combination, e.g. 
'eigen+rw_landing' + is_undirected: True if the graph is expected to be undirected + cfg: Main configuration node + + Returns: + Extended PyG Data object. + """ + # Verify PE types. + for t in pe_types: + if t not in ['LapPE', 'EquivStableLapPE', 'SignNet', + 'RWSE', 'HKdiagSE', 'HKfullPE', 'ElstaticSE','RRWP']: + raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}") + + if 'RRWP' in pe_types: + param = cfg.posenc_RRWP + transform = partial(add_full_rrwp, + walk_length=param.ksteps, + attr_name_abs="rrwp", + attr_name_rel="rrwp", + add_identity=True + ) + data = transform(data) + + # Random Walks. + if 'RWSE' in pe_types: + kernel_param = cfg.posenc_RWSE.kernel + if hasattr(data, 'num_nodes'): + N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa + else: + N = data.x.shape[0] # Number of nodes, including disconnected nodes. + if kernel_param.times == 0: + raise ValueError("List of kernel times required for RWSE") + rw_landing = get_rw_landing_probs( + ksteps=[xx + 1 for xx in range(kernel_param.times)], + edge_index=data.edge_index, + num_nodes=N + ) + data.pestat_RWSE = rw_landing + + return data + + + +def get_rw_landing_probs(ksteps, edge_index, edge_weight=None, + num_nodes=None, space_dim=0): + """Compute Random Walk landing probabilities for given list of K steps. + + Args: + ksteps: List of k-steps for which to compute the RW landings + edge_index: PyG sparse representation of the graph + edge_weight: (optional) Edge weights + num_nodes: (optional) Number of nodes in the graph + space_dim: (optional) Estimated dimensionality of the space. Used to + correct the random-walk diagonal by a factor `k^(space_dim/2)`. + In euclidean space, this correction means that the height of + the gaussian distribution stays almost constant across the number of + steps, if `space_dim` is the dimension of the euclidean space. 
+ + Returns: + 2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs + """ + if edge_weight is None: + edge_weight = torch.ones(edge_index.size(1), device=edge_index.device) + num_nodes = maybe_num_nodes(edge_index, num_nodes) + source, dest = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, source, dim=0, dim_size=num_nodes) # Out degrees. + deg_inv = deg.pow(-1.) + deg_inv.masked_fill_(deg_inv == float('inf'), 0) + + if edge_index.numel() == 0: + P = edge_index.new_zeros((1, num_nodes, num_nodes)) + else: + # P = D^-1 * A + P = torch.diag(deg_inv) @ to_dense_adj(edge_index, max_num_nodes=num_nodes) # 1 x (Num nodes) x (Num nodes) + rws = [] + if ksteps == list(range(min(ksteps), max(ksteps) + 1)): + # Efficient way if ksteps are a consecutive sequence (most of the time the case) + Pk = P.clone().detach().matrix_power(min(ksteps)) + for k in range(min(ksteps), max(ksteps) + 1): + rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + Pk = Pk @ P + else: + # Explicitly raising P to power k for each k \in ksteps. 
+ for k in ksteps: + rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) * \ + (k ** (space_dim / 2))) + rw_landing = torch.cat(rws, dim=0).transpose(0, 1) # (Num nodes) x (K steps) + + return rw_landing + +class ComputePosencStat(BaseTransform): + def __init__(self, pe_types, cfg): + self.pe_types = pe_types + self.cfg = cfg + + def forward(self, data: Any) -> Any: + pass + + def __call__(self, data) -> Data: + if isinstance(data, HeteroData): + return self._call_hetero(data) + + data = compute_posenc_stats(data, + pe_types=self.pe_types, + cfg=self.cfg + ) + return data + + def _call_hetero(self, data: HeteroData) -> HeteroData: + """Compute PE on the bus-only subgraph and store results on data['bus'].""" + bus_data = Data( + x=data["bus"].x, + edge_index=data["bus", "connects", "bus"].edge_index, + num_nodes=data["bus"].num_nodes, + ) + if hasattr(data["bus", "connects", "bus"], "edge_weight"): + bus_data.edge_weight = data["bus", "connects", "bus"].edge_weight + + bus_data = compute_posenc_stats( + bus_data, pe_types=self.pe_types, cfg=self.cfg, + ) + + # Copy computed PE attributes back onto the HeteroData bus store + pe_attrs = [ + "pestat_RWSE", # RWSE + "rrwp", "rrwp_index", "rrwp_val", # RRWP + "log_deg", "deg", # degree info from RRWP + ] + for attr in pe_attrs: + if hasattr(bus_data, attr): + data["bus"][attr] = getattr(bus_data, attr) + + return data diff --git a/gridfm_graphkit/datasets/rrwp.py b/gridfm_graphkit/datasets/rrwp.py new file mode 100644 index 0000000..acbe112 --- /dev/null +++ b/gridfm_graphkit/datasets/rrwp.py @@ -0,0 +1,87 @@ +from typing import Any, Optional +import torch +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_sparse import SparseTensor + + +def add_node_attr(data: Data, value: Any, + attr_name: Optional[str] = None) -> Data: + if attr_name is None: + if 'x' in data: + x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x + data.x = torch.cat([x, value.to(x.device, x.dtype)], 
dim=-1) + else: + data.x = value + else: + data[attr_name] = value + + return data + + + +@torch.no_grad() +def add_full_rrwp(data, + walk_length=8, + attr_name_abs="rrwp", # name: 'rrwp' + attr_name_rel="rrwp", # name: ('rrwp_idx', 'rrwp_val') + add_identity=True, + spd=False, + **kwargs + ): + num_nodes = data.num_nodes + edge_index, edge_weight = data.edge_index, data.edge_weight + + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_sizes=(num_nodes, num_nodes), + ) + + # Compute D^{-1} A: + deg = adj.sum(dim=1) + deg_inv = 1.0 / adj.sum(dim=1) + deg_inv[deg_inv == float('inf')] = 0 + adj = adj * deg_inv.view(-1, 1) + adj = adj.to_dense() + + pe_list = [] + i = 0 + if add_identity: + pe_list.append(torch.eye(num_nodes, dtype=torch.float)) + i = i + 1 + + out = adj + pe_list.append(adj) + + if walk_length > 2: + for j in range(i + 1, walk_length): + out = out @ adj + pe_list.append(out) + + pe = torch.stack(pe_list, dim=-1) # n x n x k + + abs_pe = pe.diagonal().transpose(0, 1) # n x k + + rel_pe = SparseTensor.from_dense(pe, has_value=True) + rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo() + # rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0) + rel_pe_idx = torch.stack([rel_pe_col, rel_pe_row], dim=0) + # the framework of GRIT performing right-mul while adj is row-normalized, + # need to switch the order or row and col. + # note: both can work but the current version is more reasonable. 
+ + + if spd: + spd_idx = walk_length - torch.arange(walk_length) + val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0) + val = torch.argmax(val, dim=-1) + rel_pe_val = F.one_hot(val, walk_length).type(torch.float) + abs_pe = torch.zeros_like(abs_pe) + + data = add_node_attr(data, abs_pe, attr_name=attr_name_abs) + data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index") + data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val") + data.log_deg = torch.log(deg + 1) + data.deg = deg.type(torch.long) + + return data + diff --git a/gridfm_graphkit/datasets/task_transforms.py b/gridfm_graphkit/datasets/task_transforms.py index eaaca66..d6f1b8c 100644 --- a/gridfm_graphkit/datasets/task_transforms.py +++ b/gridfm_graphkit/datasets/task_transforms.py @@ -8,6 +8,7 @@ from gridfm_graphkit.datasets.masking import ( AddOPFHeteroMask, AddPFHeteroMask, + AddRandomHeteroMask, SimulateMeasurements, ) from gridfm_graphkit.io.registries import TRANSFORM_REGISTRY @@ -20,7 +21,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose @@ -34,7 +41,13 @@ def __init__(self, args): transforms.append(RemoveInactiveBranches()) transforms.append(RemoveInactiveGenerators()) - transforms.append(AddOPFHeteroMask()) + + mask_type = getattr(args.data, "mask_type", None) + if mask_type == "rnd": + transforms.append(AddRandomHeteroMask(mask_ratio=args.data.mask_ratio)) + else: + transforms.append(AddOPFHeteroMask()) + transforms.append(ApplyMasking(args=args)) # Pass the list of transforms to Compose diff --git a/gridfm_graphkit/models/__init__.py 
b/gridfm_graphkit/models/__init__.py index f824535..b30680e 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,6 @@ + from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.grit_transformer import GritHeteroAdapter from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, PhysicsDecoderPF, @@ -7,7 +9,9 @@ __all__ = [ "GNS_heterogeneous", + "GritHeteroAdapter", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", ] + diff --git a/gridfm_graphkit/models/gnn_heterogeneous_gns.py b/gridfm_graphkit/models/gnn_heterogeneous_gns.py index 1073560..e274a1c 100644 --- a/gridfm_graphkit/models/gnn_heterogeneous_gns.py +++ b/gridfm_graphkit/models/gnn_heterogeneous_gns.py @@ -146,14 +146,20 @@ def __init__(self, args) -> None: # container for monitoring residual norms per layer and type self.layer_residuals = {} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + def forward(self, batch): """ - x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} - edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") - edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) - batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. - mask: optional mask per node (applies when computing residuals) + Accepts a PyG HeteroData batch and extracts the required tensors. 
+ + batch: HeteroData/Batch containing: + x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) + mask_dict: dict mapping node/bus types to mask tensors """ + x_dict = batch.x_dict + edge_index_dict = batch.edge_index_dict + edge_attr_dict = batch.edge_attr_dict + mask_dict = batch.mask_dict self.layer_residuals = {} diff --git a/gridfm_graphkit/models/grit_layer.py b/gridfm_graphkit/models/grit_layer.py new file mode 100644 index 0000000..a1ffc4a --- /dev/null +++ b/gridfm_graphkit/models/grit_layer.py @@ -0,0 +1,347 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_geometric as pyg +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_scatter import scatter, scatter_max, scatter_add + +import opt_einsum as oe + +import warnings + + +def pyg_softmax(src, index, num_nodes=None): + """ + Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements for applying the softmax. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`) + + Returns: + out (Tensor) + """ + + num_nodes = maybe_num_nodes(index, num_nodes) + + out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index] + out = out.exp() + out = out / ( + scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16) + + return out + + + +class MultiHeadAttentionLayerGritSparse(nn.Module): + """ + Attention Computation for GRIT + """ + + def __init__(self, in_dim, out_dim, num_heads, use_bias, + clamp=5., dropout=0., act=None, + edge_enhance=True, + sqrt_relu=False, + signed_sqrt=True, + cfg={}, + **kwargs): + super().__init__() + + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = nn.Dropout(dropout) + self.clamp = np.abs(clamp) if clamp is not None else None + self.edge_enhance = edge_enhance + + self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True) + self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True) + self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) + nn.init.xavier_normal_(self.Q.weight) + nn.init.xavier_normal_(self.K.weight) + nn.init.xavier_normal_(self.E.weight) + nn.init.xavier_normal_(self.V.weight) + + self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True) + nn.init.xavier_normal_(self.Aw) + + if act is None: + self.act = nn.Identity() + else: + self.act = nn.ReLU() + + if self.edge_enhance: + self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True) + nn.init.xavier_normal_(self.VeRow) + + def propagate_attention(self, batch): + src = batch.K_h[batch.edge_index[0]] # (num relative) x num_heads x out_dim + dest = batch.Q_h[batch.edge_index[1]] # (num relative) x num_heads x out_dim + score = src + dest # element-wise multiplication + + if batch.get("E", None) is not None: + batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2) + E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, 
self.out_dim:] + # (num relative) x num_heads x out_dim + score = score * E_w + score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score)) + score = score + E_b + + score = self.act(score) + e_t = score + + # output edge + if batch.get("E", None) is not None: + batch.wE = score.flatten(1) + + # final attn + score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch") + if self.clamp is not None: + score = torch.clamp(score, min=-self.clamp, max=self.clamp) + + raw_attn = score + score = pyg_softmax(score, batch.edge_index[1]) # (num relative) x num_heads x 1 + score = self.dropout(score) + batch.attn = score + + # Aggregate with Attn-Score + msg = batch.V_h[batch.edge_index[0]] * score # (num relative) x num_heads x out_dim + batch.wV = torch.zeros_like(batch.V_h) # (num nodes in batch) x num_heads x out_dim + scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add') + + if self.edge_enhance and batch.E is not None: + rowV = scatter(e_t * score, batch.edge_index[1], dim=0, reduce="add") + rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch") + batch.wV = batch.wV + rowV + + def forward(self, batch): + Q_h = self.Q(batch.x) + K_h = self.K(batch.x) + + V_h = self.V(batch.x) + if batch.get("edge_attr", None) is not None: + batch.E = self.E(batch.edge_attr) + else: + batch.E = None + + batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim) + batch.K_h = K_h.view(-1, self.num_heads, self.out_dim) + batch.V_h = V_h.view(-1, self.num_heads, self.out_dim) + self.propagate_attention(batch) + h_out = batch.wV + e_out = batch.get('wE', None) + + return h_out, e_out + + +class GritTransformerLayer(nn.Module): + """ + Transformer Layer for GRIT + """ + def __init__(self, in_dim, out_dim, num_heads, + dropout=0.0, + attn_dropout=0.0, + layer_norm=False, batch_norm=True, + residual=True, + act='relu', + norm_e=True, + O_e=True, + cfg=dict(), + **kwargs): + super().__init__() + + self.debug = False + self.in_channels = in_dim + 
self.out_channels = out_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.dropout = dropout + self.residual = residual + self.layer_norm = layer_norm + self.batch_norm = batch_norm + + # ------- + self.update_e = getattr(cfg.attn, "update_e", True) + self.bn_momentum = cfg.attn.bn_momentum + self.bn_no_runner = cfg.attn.bn_no_runner + self.rezero = getattr(cfg.attn, "rezero", False) + + if act is not None: + self.act = nn.ReLU() + else: + self.act = nn.Identity() + + if getattr(cfg, "attn", None) is None: + cfg.attn = dict() + self.use_attn = getattr(cfg.attn, "use", True) + self.deg_scaler = getattr(cfg.attn, "deg_scaler", True) + + self.attention = MultiHeadAttentionLayerGritSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=getattr(cfg.attn, "use_bias", False), + dropout=attn_dropout, + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=getattr(cfg.attn, "edge_enhance", True), + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn,"scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), + ) + + if getattr(cfg.attn, 'graphormer_attn', False): + self.attention = MultiHeadAttentionLayerGraphormerSparse( + in_dim=in_dim, + out_dim=out_dim // num_heads, + num_heads=num_heads, + use_bias=getattr(cfg.attn, "use_bias", False), + dropout=attn_dropout, + clamp=getattr(cfg.attn, "clamp", 5.), + act=getattr(cfg.attn, "act", "relu"), + edge_enhance=True, + sqrt_relu=getattr(cfg.attn, "sqrt_relu", False), + signed_sqrt=getattr(cfg.attn, "signed_sqrt", False), + scaled_attn =getattr(cfg.attn, "scaled_attn", False), + no_qk=getattr(cfg.attn, "no_qk", False), + ) + + + + self.O_h = nn.Linear(out_dim//num_heads * num_heads, out_dim) + if O_e: + self.O_e = nn.Linear(out_dim//num_heads * num_heads, out_dim) + else: + self.O_e = nn.Identity() + + # -------- Deg Scaler Option 
------ + + if self.deg_scaler: + self.deg_coef = nn.Parameter(torch.zeros(1, out_dim//num_heads * num_heads, 2)) + nn.init.xavier_normal_(self.deg_coef) + + if self.layer_norm: + self.layer_norm1_h = nn.LayerNorm(out_dim) + self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity() + + if self.batch_norm: + # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN) + self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) if norm_e else nn.Identity() + + # FFN for h + self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2) + self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim) + + if self.layer_norm: + self.layer_norm2_h = nn.LayerNorm(out_dim) + + if self.batch_norm: + self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.attn.bn_momentum) + + if self.rezero: + self.alpha1_h = nn.Parameter(torch.zeros(1,1)) + self.alpha2_h = nn.Parameter(torch.zeros(1,1)) + self.alpha1_e = nn.Parameter(torch.zeros(1,1)) + + def forward(self, batch): + h = batch.x + num_nodes = batch.num_nodes + log_deg = get_log_deg(batch) + + h_in1 = h # for first residual connection + e_in1 = batch.get("edge_attr", None) + e = None + # multi-head attention out + + h_attn_out, e_attn_out = self.attention(batch) + + h = h_attn_out.view(num_nodes, -1) + h = F.dropout(h, self.dropout, training=self.training) + + # degree scaler + if self.deg_scaler: + h = torch.stack([h, h * log_deg], dim=-1) + h = (h * self.deg_coef).sum(dim=-1) + + h = self.O_h(h) + if e_attn_out is not None: + e = e_attn_out.flatten(1) + e = F.dropout(e, self.dropout, training=self.training) + e = self.O_e(e) + + if self.residual: + if self.rezero: h = h * self.alpha1_h + h = h_in1 + h # residual 
connection + if e is not None: + if self.rezero: e = e * self.alpha1_e + e = e + e_in1 + + if self.layer_norm: + h = self.layer_norm1_h(h) + if e is not None: e = self.layer_norm1_e(e) + + if self.batch_norm: + h = self.batch_norm1_h(h) + if e is not None: e = self.batch_norm1_e(e) + + # FFN for h + h_in2 = h # for second residual connection + h = self.FFN_h_layer1(h) + h = self.act(h) + h = F.dropout(h, self.dropout, training=self.training) + h = self.FFN_h_layer2(h) + + if self.residual: + if self.rezero: h = h * self.alpha2_h + h = h_in2 + h # residual connection + + if self.layer_norm: + h = self.layer_norm2_h(h) + + if self.batch_norm: + h = self.batch_norm2_h(h) + + batch.x = h + if self.update_e: + batch.edge_attr = e + else: + batch.edge_attr = e_in1 + + return batch + + def __repr__(self): + return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format( + self.__class__.__name__, + self.in_channels, + self.out_channels, self.num_heads, self.residual, + super().__repr__(), + ) + + +@torch.no_grad() +def get_log_deg(batch): + if "log_deg" in batch: + log_deg = batch.log_deg + elif "deg" in batch: + deg = batch.deg + log_deg = torch.log(deg + 1).unsqueeze(-1) + else: + warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs") + deg = pyg.utils.degree(batch.edge_index[1], + num_nodes=batch.num_nodes, + dtype=torch.float + ) + log_deg = torch.log(deg + 1) + log_deg = log_deg.view(batch.num_nodes, 1) + return log_deg + + diff --git a/gridfm_graphkit/models/grit_transformer.py b/gridfm_graphkit/models/grit_transformer.py new file mode 100644 index 0000000..3422a21 --- /dev/null +++ b/gridfm_graphkit/models/grit_transformer.py @@ -0,0 +1,403 @@ +from gridfm_graphkit.io.registries import MODELS_REGISTRY +import torch +from torch import nn +from torch_geometric.data import Data + +from gridfm_graphkit.models.rrwp_encoder import RRWPLinearNodeEncoder, RRWPLinearEdgeEncoder +from 
gridfm_graphkit.models.grit_layer import GritTransformerLayer +from gridfm_graphkit.models.kernel_pos_encoder import RWSENodeEncoder +from torch_scatter import scatter_add +from gridfm_graphkit.datasets.globals import PG_H + + +class BatchNorm1dNode(torch.nn.Module): + r"""A batch normalization layer for node-level features. + + Args: + dim_in (int): BatchNorm input dimension. + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.x = self.bn(batch.x) + return batch + + +class BatchNorm1dEdge(torch.nn.Module): + r"""A batch normalization layer for edge-level features. + + Args: + dim_in (int): BatchNorm input dimension. + eps (float): BatchNorm eps. + momentum (float): BatchNorm momentum. + """ + def __init__(self, dim_in, eps, momentum): + super().__init__() + self.bn = torch.nn.BatchNorm1d( + dim_in, + eps=eps, + momentum=momentum, + ) + + def forward(self, batch): + batch.edge_attr = self.bn(batch.edge_attr) + return batch + + +class LinearNodeEncoder(torch.nn.Module): + def __init__(self, dim_in, emb_dim): + super().__init__() + + self.encoder = torch.nn.Linear(dim_in, emb_dim) + + def forward(self, batch): + batch.x = self.encoder(batch.x) + return batch + +class LinearEdgeEncoder(torch.nn.Module): + def __init__(self, edge_dim, emb_dim): + super().__init__() + + self.in_dim = edge_dim + + self.encoder = torch.nn.Linear(self.in_dim, emb_dim) + + def forward(self, batch): + batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) + return batch + + +class FeatureEncoder(torch.nn.Module): + """ + Encoding node and edge features + + Args: + dim_in (int): Input feature dimension + + """ + def __init__( + self, + dim_in, + dim_inner, + args + ): + super(FeatureEncoder, self).__init__() + self.dim_in = dim_in + if args.encoder.node_encoder: + # Encode 
integer node features via nn.Embeddings + if 'RWSE' in args.encoder.node_encoder_name: + self.node_encoder = RWSENodeEncoder(self.dim_in, dim_inner, args.encoder.posenc_RWSE) + else: + self.node_encoder = LinearNodeEncoder(self.dim_in, dim_inner) + if args.encoder.node_encoder_bn: + self.node_encoder_bn = BatchNorm1dNode(dim_inner, 1e-5, 0.1) + # Update dim_in to reflect the new dimension fo the node features + self.dim_in = dim_inner + if args.encoder.edge_encoder: + edge_dim = args.edge_dim + enc_dim_edge = dim_inner + # Encode integer edge features via nn.Embeddings + self.edge_encoder = LinearEdgeEncoder(edge_dim, enc_dim_edge) + if args.encoder.edge_encoder_bn: + self.edge_encoder_bn = BatchNorm1dEdge(enc_dim_edge, 1e-5, 0.1) + + def forward(self, batch): + for module in self.children(): + batch = module(batch) + return batch + +class GraphHead(nn.Module): + """ + Prediction head for decoding tasks. + Args: + dim_in (int): Input dimension. + dim_out (int): Output dimension. For binary prediction, dim_out=1. + L (int): Number of hidden layers. + """ + + def __init__(self, dim_in, dim_out): + super().__init__() + + self.FC_layers = nn.Sequential( + nn.Linear(dim_in, dim_in), + nn.LeakyReLU(), + nn.Linear(dim_in, dim_out), + ) + + def _apply_index(self, batch): + return batch.graph_feature, batch.y + + def forward(self, batch): + graph_emb = self.FC_layers(batch.x) + batch.graph_feature = graph_emb + pred, label = self._apply_index(batch) + return pred + + +class GritTransformer(torch.nn.Module): + """ + The GritTransformer (Graph Inductive Bias Transformer) from + Graph Inductive Biases in Transformers without Message Passing, L. Ma et al., + 2023. 
+ + """ + def __init__(self, args, include_decoder=True): + super().__init__() + + + dim_in = args.model.input_dim + dim_out = args.model.output_dim + dim_inner = args.model.hidden_size + dim_edge = args.model.edge_dim + num_heads = args.model.attention_head + dropout = args.model.dropout + num_layers = args.model.num_layers + self.mask_dim = getattr(args.data, "mask_dim", 6) + self.mask_value = getattr(args.data, "mask_value", -1.0) + self.learn_mask = getattr(args.data, "learn_mask", False) + if self.learn_mask: + self.mask_value = nn.Parameter( + torch.randn(self.mask_dim) + self.mask_value, + requires_grad=True, + ) + else: + self.mask_value = nn.Parameter( + torch.zeros(self.mask_dim) + self.mask_value, + requires_grad=False, + ) + + self.encoder = FeatureEncoder( + dim_in, + dim_inner, + args.model + ) + dim_in = self.encoder.dim_in + + if args.data.posenc_RRWP.enable: + + self.rrwp_abs_encoder = RRWPLinearNodeEncoder( + args.data.posenc_RRWP.ksteps, + dim_inner + ) + rel_pe_dim = args.data.posenc_RRWP.ksteps + self.rrwp_rel_encoder = RRWPLinearEdgeEncoder( + rel_pe_dim, + dim_inner, + pad_to_full_graph=args.model.gt.attn.full_attn, + add_node_attr_as_self_loop=False, + fill_value=0. + ) + + assert args.model.hidden_size == dim_inner == dim_in, \ + "The inner and hidden dims must match." + + layers = [] + for ll in range(num_layers): + # The last layer's edge output is never consumed downstream + # (only node features feed into the output heads), so skip + # creating O_e / norm_e parameters to avoid DDP unused-parameter + # errors. 
+ is_last = (ll == num_layers - 1) + layers.append(GritTransformerLayer( + in_dim=args.model.gt.dim_hidden, + out_dim=args.model.gt.dim_hidden, + num_heads=num_heads, + dropout=dropout, + act=args.model.act, + attn_dropout=args.model.gt.attn_dropout, + layer_norm=args.model.gt.layer_norm, + batch_norm=args.model.gt.batch_norm, + residual=True, + norm_e=False if is_last else args.model.gt.attn.norm_e, + O_e=False if is_last else args.model.gt.attn.O_e, + cfg=args.model.gt, + )) + + self.layers = nn.Sequential(*layers) + + if include_decoder: + self.decoder = GraphHead(dim_inner, dim_out) + + def forward(self, batch): + """ + Forward pass for GRIT. + + Args: + batch (Batch): Pytorch Geometric Batch object, with x, y encodings, etc. + + Returns: + output (Tensor): Output node features of shape [num_nodes, output_dim]. + """ + # print('xxxx',batch.x.min(), batch.x.max()) + # print('yyyyy',batch.y.min(), batch.y.max()) + # print('>>>>', batch) + for module in self.children(): + batch = module(batch) + + return batch + +def aggregate_pg(batch, mask_value=-1.0): + """Aggregate per-generator active power (PG) onto bus nodes. + + In the homogeneous reference, PG is a direct bus feature visible to the + transformer alongside Pd, Qd, Vm, Va, etc. In the heterogeneous + representation PG lives on separate generator nodes, so it must be + aggregated onto buses before the transformer can learn voltage-power + coupling. + + Masked generators (where PG has been replaced by the mask value) are + excluded from the sum to avoid corrupting the aggregated signal. Buses + where *all* connected generators are masked receive the mask value + instead, preserving a consistent "unknown" indicator. 
+ """ + gen_to_bus = batch["gen", "connected_to", "bus"].edge_index + gen_pg = batch["gen"].x[:, PG_H] + gen_masked = batch.mask_dict["gen"][:, PG_H] # True = masked + + # Zero out masked generators so they don't contribute to the sum + pg_clean = torch.where(gen_masked, torch.zeros_like(gen_pg), gen_pg) + + pg_per_bus = scatter_add( + pg_clean, + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + + # Check which buses have ALL generators masked (or no generators at all) + unmasked_count = scatter_add( + (~gen_masked).float(), + gen_to_bus[1], + dim=0, + dim_size=batch["bus"].x.size(0), + ) + all_masked = unmasked_count == 0 + + # Set mask_value for fully-masked buses + pg_per_bus[all_masked] = mask_value + + return pg_per_bus + + +@MODELS_REGISTRY.register("GRIT") +class GritHeteroAdapter(torch.nn.Module): + """Adapter that enables the homogeneous GRIT transformer to operate on + heterogeneous power-grid graphs. + + Extracts the bus-only homogeneous subgraph using PyG's native HeteroData + accessors, runs it through the GRIT encoder and transformer layers, and + produces per-node-type predictions. Generator output comes from a + lightweight standalone MLP (generators are not seen by the transformer). + + Returns: + dict: ``{"bus": Tensor[num_bus, output_bus_dim], + "gen": Tensor[num_gen, output_gen_dim]}`` + """ + + def __init__(self, args): + super().__init__() + + dim_inner = args.model.hidden_size + output_bus_dim = args.model.output_bus_dim + output_gen_dim = args.model.output_gen_dim + input_gen_dim = args.model.input_gen_dim + + # Ensure config keys expected by GritTransformer are present. 
+ # input_dim = bus feature dimension (used by FeatureEncoder) + # output_dim = bus output dimension (used by the unused GraphHead) + if not hasattr(args.model, "input_dim"): + args.model.input_dim = args.model.input_bus_dim + if not hasattr(args.model, "output_dim"): + args.model.output_dim = output_bus_dim + + # Sync PE kernel.times from data config into model encoder config so + # users only need to specify it once (under data.posenc_RWSE). + if ( + hasattr(args.data, "posenc_RWSE") + and args.data.posenc_RWSE.enable + and hasattr(args.model, "encoder") + and hasattr(args.model.encoder, "posenc_RWSE") + ): + from gridfm_graphkit.io.param_handler import NestedNamespace + enc_rwse = args.model.encoder.posenc_RWSE + if not hasattr(enc_rwse, "kernel"): + enc_rwse.kernel = NestedNamespace() + enc_rwse.kernel.times = args.data.posenc_RWSE.kernel.times + + # Sync gt.dim_hidden from model.hidden_size so it is specified once. + if hasattr(args.model, "gt"): + args.model.gt.dim_hidden = args.model.hidden_size + + # The original homogeneous GRIT + # (encoder + optional PE encoders + transformer layers) + # Decoder is excluded — this adapter provides its own per-type heads. + self.grit = GritTransformer(args, include_decoder=False) + + # Per-node-type output heads (replace GraphHead for hetero output) + self.bus_head = nn.Sequential( + nn.Linear(dim_inner, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_bus_dim), + ) + # gen_head is only needed for tasks that require per-generator + # predictions (e.g. OPF cost computation). When output_gen_dim is 0 + # or not set, skip it to avoid DDP unused-parameter errors. + if output_gen_dim and output_gen_dim > 0: + self.gen_head = nn.Sequential( + nn.Linear(input_gen_dim, dim_inner), + nn.LeakyReLU(), + nn.Linear(dim_inner, output_gen_dim), + ) + else: + self.gen_head = None + + def forward(self, batch): + """Forward pass on a heterogeneous power-grid batch. 
+ + Args: + batch: A batched ``HeteroData`` with node types ``"bus"`` and + ``"gen"``, and edge type ``("bus", "connects", "bus")``. + + Returns: + dict with keys ``"bus"`` and ``"gen"``, each mapping to the + predicted output features. + """ + # --- Extract bus-only homogeneous subgraph --- + # Aggregate generator PG onto buses + pg_per_bus = aggregate_pg(batch, mask_value=self.grit.mask_value[0].item()) + bus_x = torch.cat([batch["bus"].x, pg_per_bus.unsqueeze(-1)], dim=-1) # 15 → 16D + + homo = Data( + x=bus_x, + y=batch["bus"].y, + edge_index=batch["bus", "connects", "bus"].edge_index, + edge_attr=batch["bus", "connects", "bus"].edge_attr, + batch=batch["bus"].batch, + ) + + # Forward positional-encoding attributes if present + for attr in ("pestat_RWSE", "rrwp", "rrwp_index", "rrwp_val", "log_deg", "deg"): + if hasattr(batch["bus"], attr): + setattr(homo, attr, getattr(batch["bus"], attr)) + + # --- Run GRIT encoder + PE encoders + transformer layers --- + homo = self.grit.encoder(homo) + if hasattr(self.grit, "rrwp_abs_encoder"): + homo = self.grit.rrwp_abs_encoder(homo) + if hasattr(self.grit, "rrwp_rel_encoder"): + homo = self.grit.rrwp_rel_encoder(homo) + homo = self.grit.layers(homo) + + # --- Per-type decoding --- + bus_out = self.bus_head(homo.x) + gen_out = self.gen_head(batch["gen"].x) if self.gen_head is not None else batch["gen"].x + + return {"bus": bus_out, "gen": gen_out} diff --git a/gridfm_graphkit/models/kernel_pos_encoder.py b/gridfm_graphkit/models/kernel_pos_encoder.py new file mode 100644 index 0000000..b24078d --- /dev/null +++ b/gridfm_graphkit/models/kernel_pos_encoder.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn + + +class KernelPENodeEncoder(torch.nn.Module): + """Configurable kernel-based Positional Encoding node encoder. + + The choice of which kernel-based statistics to use is configurable through + setting of `kernel_type`. 
Based on this, the appropriate config is selected, + and also the appropriate variable with precomputed kernel stats is then + selected from PyG Data graphs in `forward` function. + E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'. + + PE of size `dim_pe` will get appended to each node feature vector. + If `expand_x` set True, original node features will be first linearly + projected to (dim_emb - dim_pe) size and the concatenated with PE. + + Args: + dim_emb: Size of final node embedding + expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe) + """ + + kernel_type = None # Instantiated type of the KernelPE, e.g. RWSE + + def __init__(self, dim_in, dim_emb, pecfg, expand_x=True): + super().__init__() + if self.kernel_type is None: + raise ValueError(f"{self.__class__.__name__} has to be " + f"preconfigured by setting 'kernel_type' class" + f"variable before calling the constructor.") + + dim_pe = pecfg.pe_dim # Size of the kernel-based PE embedding + num_rw_steps = pecfg.kernel.times + norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type + # self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable + + if dim_emb - dim_pe < 1: + raise ValueError(f"PE dim size {dim_pe} is too large for " + f"desired embedding size of {dim_emb}.") + + if expand_x: + self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe) + self.expand_x = expand_x + + if norm_type == 'batchnorm': + self.raw_norm = nn.BatchNorm1d(num_rw_steps) + else: + self.raw_norm = None + + self.pe_encoder = nn.Linear(num_rw_steps, dim_pe) + + + def forward(self, batch): + pestat_var = f"pestat_{self.kernel_type}" + if not hasattr(batch, pestat_var): + raise ValueError(f"Precomputed '{pestat_var}' variable is " + f"required for {self.__class__.__name__}; set " + f"config 'posenc_{self.kernel_type}.enable' to " + f"True, and also set 'posenc.kernel.times' values") + + pos_enc = getattr(batch, pestat_var) # (Num nodes) x (Num kernel times) + # pos_enc = 
batch.rw_landing # (Num nodes) x (Num kernel times) + if self.raw_norm: + pos_enc = self.raw_norm(pos_enc) + pos_enc = self.pe_encoder(pos_enc) # (Num nodes) x dim_pe + + # Expand node features if needed + if self.expand_x: + h = self.linear_x(batch.x) + else: + h = batch.x + # Concatenate final PEs to input embedding + batch.x = torch.cat((h, pos_enc), 1) + # Keep PE also separate in a variable (e.g. for skip connections to input) + # if self.pass_as_var: + # setattr(batch, f'pe_{self.kernel_type}', pos_enc) + + return batch + + +class RWSENodeEncoder(KernelPENodeEncoder): + """Random Walk Structural Encoding node encoder. + """ + kernel_type = 'RWSE' \ No newline at end of file diff --git a/gridfm_graphkit/models/rrwp_encoder.py b/gridfm_graphkit/models/rrwp_encoder.py new file mode 100644 index 0000000..1f7fd10 --- /dev/null +++ b/gridfm_graphkit/models/rrwp_encoder.py @@ -0,0 +1,183 @@ +""" + The RRWP encoder for GRIT (ours) +""" +import torch +from torch import nn +from torch.nn import functional as F +import torch_sparse + +from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops +from torch_scatter import scatter +import warnings + +def full_edge_index(edge_index, batch=None): + """ + Return the Full batched sparse adjacency matrices given by edge indices. + Returns batched sparse adjacency matrices with exactly those edges that + are not in the input `edge_index` while ignoring self-loops. + Implementation inspired by `torch_geometric.utils.to_dense_adj` + Args: + edge_index: The edge indices. + batch: Batch vector, which assigns each node to a specific example. + Returns: + Complementary edge index. 
+ """ + + if batch is None: + batch = edge_index.new_zeros(edge_index.max().item() + 1) + + batch_size = batch.max().item() + 1 + one = batch.new_ones(batch.size(0)) + num_nodes = scatter(one, batch, + dim=0, dim_size=batch_size, reduce='add') + cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) + + negative_index_list = [] + for i in range(batch_size): + n = num_nodes[i].item() + size = [n, n] + adj = torch.ones(size, dtype=torch.short, + device=edge_index.device) + + adj = adj.view(size) + _edge_index = adj.nonzero(as_tuple=False).t().contiguous() + # _edge_index, _ = remove_self_loops(_edge_index) + negative_index_list.append(_edge_index + cum_nodes[i]) + + edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() + return edge_index_full + + + +class RRWPLinearNodeEncoder(torch.nn.Module): + """ + FC_1(RRWP) + FC_2 (Node-attr) + note: FC_2 is given by the Typedict encoder of node-attr in some cases + Parameters: + num_classes - the number of classes for the embedding mapping to learn + """ + def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): + super().__init__() + self.batchnorm = batchnorm + self.layernorm = layernorm + self.name = pe_name + + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + # Encode just the first dimension if more exist + rrwp = batch[f"{self.name}"] + rrwp = self.fc(rrwp) + + if self.batchnorm: + rrwp = self.bn(rrwp) + + if self.layernorm: + rrwp = self.ln(rrwp) + + if "x" in batch: + batch.x = batch.x + rrwp + else: + batch.x = rrwp + + return batch + + +class RRWPLinearEdgeEncoder(torch.nn.Module): + """ + Merge RRWP with given edge-attr and Zero-padding to all pairs of node + FC_1(RRWP) + FC_2(edge-attr) + - FC_2 given by the TypedictEncoder in same cases + - 
Zero-padding for non-existing edges in fully-connected graph + - (optional) add node-attr as the E_{i,i}'s attr + note: assuming node-attr and edge-attr is with the same dimension after Encoders + """ + def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, + pad_to_full_graph=True, fill_value=0., + add_node_attr_as_self_loop=False, + overwrite_old_attr=False): + super().__init__() + # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info + self.emb_dim = emb_dim + self.out_dim = out_dim + self.add_node_attr_as_self_loop = add_node_attr_as_self_loop + self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr + + self.batchnorm = batchnorm + self.layernorm = layernorm + if self.batchnorm or self.layernorm: + warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") + + # print('--------fc in and out:', emb_dim, out_dim) + self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) + torch.nn.init.xavier_uniform_(self.fc.weight) + self.pad_to_full_graph = pad_to_full_graph + self.fill_value = 0. + + padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value + self.register_buffer("padding", padding) + + if self.batchnorm: + self.bn = nn.BatchNorm1d(out_dim) + + if self.layernorm: + self.ln = nn.LayerNorm(out_dim) + + def forward(self, batch): + rrwp_idx = batch.rrwp_index + rrwp_val = batch.rrwp_val + edge_index = batch.edge_index + edge_attr = batch.edge_attr + rrwp_val = self.fc(rrwp_val) + + if edge_attr is None: + edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) + # zero padding for non-existing edges + + if self.overwrite_old_attr: + out_idx, out_val = rrwp_idx, rrwp_val + else: + edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) 
+ out_idx, out_val = torch_sparse.coalesce( + torch.cat([edge_index, rrwp_idx], dim=1), + torch.cat([edge_attr, rrwp_val], dim=0), + batch.num_nodes, batch.num_nodes, + op="add" + ) + + + if self.pad_to_full_graph: + edge_index_full = full_edge_index(out_idx, batch=batch.batch) + edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) + # zero padding to fully-connected graphs + out_idx = torch.cat([out_idx, edge_index_full], dim=1) + out_val = torch.cat([out_val, edge_attr_pad], dim=0) + out_idx, out_val = torch_sparse.coalesce( + out_idx, out_val, batch.num_nodes, batch.num_nodes, + op="add" + ) + + if self.batchnorm: + out_val = self.bn(out_val) + + if self.layernorm: + out_val = self.ln(out_val) + + + batch.edge_index, batch.edge_attr = out_idx, out_val + return batch + + def __repr__(self): + return f"{self.__class__.__name__}" \ + f"(pad_to_full_graph={self.pad_to_full_graph}," \ + f"fill_value={self.fill_value}," \ + f"{self.fc.__repr__()})" + + + diff --git a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py index 8dcfc8c..3d8118c 100644 --- a/gridfm_graphkit/tasks/compute_ac_dc_metrics.py +++ b/gridfm_graphkit/tasks/compute_ac_dc_metrics.py @@ -4,10 +4,6 @@ import os import numpy as np import pandas as pd -from gridfm_datakit.utils.power_balance import ( - compute_branch_powers_vectorized, - compute_bus_balance, -) N_SCENARIO_PER_PARTITION = 200 NUM_PROCESSES = 64 @@ -132,6 +128,11 @@ def compute_ac_dc_metrics( bus_df, branch_df, runtime_df = _load_test_data(data_dir, test_ids) + from gridfm_datakit.utils.power_balance import ( + compute_branch_powers_vectorized, + compute_bus_balance, + ) + # ========================= # AC residuals # ========================= diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index b28c5a0..e2f0bcd 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -233,18 +233,23 @@ def test_step(self, batch, 
batch_idx, dataloader_idx=0): loss_dict["Active Power Loss"] = final_residual_real_bus.detach() loss_dict["Reactive Power Loss"] = final_residual_imag_bus.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. + output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) diff --git a/gridfm_graphkit/tasks/pf_task.py b/gridfm_graphkit/tasks/pf_task.py index cdc9d64..46be699 100644 --- a/gridfm_graphkit/tasks/pf_task.py +++ b/gridfm_graphkit/tasks/pf_task.py @@ -5,6 +5,8 @@ QG_H, VM_H, VA_H, + # Generator feature indices + PG_H, # Output feature indices VM_OUT, VA_OUT, @@ -34,11 +36,67 @@ import pandas as pd +def _build_bus_target(batch, num_bus): + """Build a 4-column bus-level target tensor [VM, VA, PG_agg, QG]. + + Generator PG is aggregated onto buses via scatter_add so that the + target layout matches the bus head output columns. + """ + _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] + agg_gen_on_bus = scatter_add( + batch.y_dict["gen"], + gen_to_bus_index, + dim=0, + dim_size=num_bus, + ) + target = torch.stack( + [ + batch.y_dict["bus"][:, VM_H], + batch.y_dict["bus"][:, VA_H], + agg_gen_on_bus.squeeze(), + batch.y_dict["bus"][:, QG_H], + ], + dim=1, + ) + return target, gen_to_bus_index, agg_gen_on_bus + + +def _clamp_known_to_ground_truth(output_bus, target, batch, gen_to_bus_index, num_bus): + """Replace predicted values with ground truth for known (unmasked) quantities. 
+ + During both training (PBELoss) and evaluation, the model is only + responsible for predicting masked unknowns. Known quantities (e.g. + VM at PV buses, VA at REF, PG at non-slack generators) are clamped to + ground truth so that prediction errors on non-target outputs do not + pollute the power-balance residual. + """ + mask_bus = batch.mask_dict["bus"] + eval_bus = output_bus.clone() + eval_bus[:, VM_OUT] = torch.where( + mask_bus[:, VM_H], output_bus[:, VM_OUT], target[:, VM_OUT], + ) + eval_bus[:, VA_OUT] = torch.where( + mask_bus[:, VA_H], output_bus[:, VA_OUT], target[:, VA_OUT], + ) + gen_pg_masked = batch.mask_dict["gen"][:, PG_H].float() + any_gen_masked = ( + scatter_add(gen_pg_masked, gen_to_bus_index, dim=0, dim_size=num_bus) > 0 + ) + eval_bus[:, PG_OUT] = torch.where( + any_gen_masked, output_bus[:, PG_OUT], target[:, PG_OUT], + ) + eval_bus[:, QG_OUT] = torch.where( + mask_bus[:, QG_H], output_bus[:, QG_OUT], target[:, QG_OUT], + ) + return eval_bus + + @TASK_REGISTRY.register("PowerFlow") class PowerFlowTask(ReconstructionTask): """ - Concrete Optimal Power Flow task. - Extends ReconstructionTask and adds OPF-specific metrics. + Concrete Power Flow task. + Extends ReconstructionTask and adds PF-specific evaluation metrics + (power balance residuals, per-bus-type RMSE). 
""" def __init__(self, args, data_normalizers): @@ -58,34 +116,18 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): num_bus = batch.x_dict["bus"].size(0) bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - # output_agg = torch.cat([batch.y_dict["bus"], agg_gen_on_bus], dim=1) - target = torch.stack( - [ - batch.y_dict["bus"][:, VM_H], - batch.y_dict["bus"][:, VA_H], - agg_gen_on_bus.squeeze(), - batch.y_dict["bus"][:, QG_H], - ], - dim=1, + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, ) - # UN-COMMENT THIS TO CHECK PBE ON GROUND TRUTH - # output["bus"] = target - - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) @@ -166,18 +208,23 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["PBE Mean"] = pbe_mean.detach() + # Slice output to the 4 target columns [VM, VA, PG, QG] so that + # models with wider bus output (e.g. GRIT with output_bus_dim=6) + # are compared correctly against the 4-column target. 
+ output_bus_metrics = output["bus"][:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT]] + mse_PQ = F.mse_loss( - output["bus"][mask_PQ], + output_bus_metrics[mask_PQ], target[mask_PQ], reduction="none", ) mse_PV = F.mse_loss( - output["bus"][mask_PV], + output_bus_metrics[mask_PV], target[mask_PV], reduction="none", ) mse_REF = F.mse_loss( - output["bus"][mask_REF], + output_bus_metrics[mask_REF], target[mask_REF], reduction="none", ) @@ -364,12 +411,17 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): bus_edge_index = batch.edge_index_dict[("bus", "connects", "bus")] bus_edge_attr = batch.edge_attr_dict[("bus", "connects", "bus")] - Pft, Qft = branch_flow_layer(output["bus"], bus_edge_index, bus_edge_attr) + target, gen_to_bus_index, agg_gen_on_bus = _build_bus_target(batch, num_bus) + eval_bus = _clamp_known_to_ground_truth( + output["bus"], target, batch, gen_to_bus_index, num_bus, + ) + + Pft, Qft = branch_flow_layer(eval_bus, bus_edge_index, bus_edge_attr) P_in, Q_in = node_injection_layer(Pft, Qft, bus_edge_index, num_bus) residual_P, residual_Q = node_residuals_layer( P_in, Q_in, - output["bus"], + eval_bus, batch.x_dict["bus"], ) residual_P = torch.abs(residual_P) @@ -391,14 +443,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): mask_PV = batch.mask_dict["PV"] mask_REF = batch.mask_dict["REF"] - _, gen_to_bus_index = batch.edge_index_dict[("gen", "connected_to", "bus")] - agg_gen_on_bus = scatter_add( - batch.y_dict["gen"], - gen_to_bus_index, - dim=0, - dim_size=num_bus, - ) - return { "scenario": scenario_ids.cpu().numpy(), "bus": local_bus_idx.cpu().numpy(), diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 8742646..43e243c 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -39,16 +39,11 @@ def __init__(self, args, data_normalizers): self.batch_size = int(args.training.batch_size) self.test_outputs = {i: [] for i in 
range(len(args.data.networks))} - def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): - return self.model(x_dict, edge_index_dict, edge_attr_dict, mask_dict) + def forward(self, batch): + return self.model(batch) def shared_step(self, batch): - output = self.forward( - x_dict=batch.x_dict, - edge_index_dict=batch.edge_index_dict, - edge_attr_dict=batch.edge_attr_dict, - mask_dict=batch.mask_dict, - ) + output = self.forward(batch) loss_dict = self.loss_fn( output, diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d253d2b..d700b89 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from gridfm_graphkit.io.registries import LOSS_REGISTRY from torch_scatter import scatter_add +from torch_geometric.utils import to_torch_coo_tensor from gridfm_graphkit.datasets.globals import ( # Bus feature indices @@ -17,8 +18,15 @@ VA_OUT, QG_OUT, PG_OUT, + PD_OUT, + QD_OUT, # Generator feature indices PG_H, + # Edge feature indices + YFF_TT_R, + YFF_TT_I, + YFT_TF_R, + YFT_TF_I, ) @@ -92,9 +100,12 @@ def forward( mask_dict, model=None, ): + gen_pred = pred_dict["gen"][:, : (PG_H + 1)] + gen_target = target_dict["gen"][:, : (PG_H + 1)] + mask = mask_dict["gen"][:, : (PG_H + 1)] loss = F.mse_loss( - pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], - target_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], + gen_pred[mask], + gen_target[mask], reduction=self.reduction, ) return {"loss": loss, "Masked generator MSE loss": loss.detach()} @@ -136,6 +147,81 @@ def forward( return {"loss": loss, "Masked bus MSE loss": loss.detach()} +@LOSS_REGISTRY.register("MaskedReconstructionMSE") +class MaskedReconstructionMSE(BaseLoss): + """Unified masked MSE over bus-level quantities [VM, VA, PG, QG, PD, QD]. + + Mirrors the homogeneous reference MaskedMSE by combining bus predictions + and aggregated generator PG into a single prediction/target/mask tensor. 
+ PG targets are aggregated from generator ground truth onto buses via + scatter_add; the bus-level PG mask is True when any generator at the bus + is masked, indicating that the model must reconstruct that quantity. + + Replaces the separate MaskedBusMSE + MaskedGenMSE pair. + Requires output_bus_dim >= 6 so the bus head predicts + [VM, VA, PG, QG, PD, QD]. + """ + + def __init__(self, loss_args, args): + super().__init__() + self.reduction = "mean" + + def forward( + self, + pred_dict, + target_dict, + edge_index_dict, + edge_attr_dict, + mask_dict, + model=None, + ): + pred_bus = pred_dict["bus"] + target_bus = target_dict["bus"] + num_bus = target_bus.size(0) + gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")] + + # --- Build target: [VM, VA, PG_agg, QG, PD, QD] --- + target_pg_agg = scatter_add( + target_dict["gen"][:, PG_H], + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) + target = torch.stack([ + target_bus[:, VM_H], + target_bus[:, VA_H], + target_pg_agg, + target_bus[:, QG_H], + target_bus[:, PD_H], + target_bus[:, QD_H], + ], dim=1) + + # --- Build mask: [N_bus, 6] --- + # PG bus-level mask: True if any generator at the bus has PG masked + gen_pg_masked = mask_dict["gen"][:, PG_H].float() + any_gen_masked = scatter_add( + gen_pg_masked, + gen_to_bus_ei[1], + dim=0, + dim_size=num_bus, + ) > 0 + + mask = torch.stack([ + mask_dict["bus"][:, VM_H], + mask_dict["bus"][:, VA_H], + any_gen_masked, + mask_dict["bus"][:, QG_H], + mask_dict["bus"][:, PD_H], + mask_dict["bus"][:, QD_H], + ], dim=1) + + # --- Prediction: [VM, VA, PG, QG, PD, QD] from bus head --- + pred = pred_bus[:, [VM_OUT, VA_OUT, PG_OUT, QG_OUT, PD_OUT, QD_OUT]] + + loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) + return {"loss": loss, "Masked reconstruction MSE loss": loss.detach()} + + @LOSS_REGISTRY.register("MSE") class MSELoss(BaseLoss): """Standard Mean Squared Error loss.""" @@ -322,3 +408,174 @@ def forward( f"MSE loss {self.dim}": 
@LOSS_REGISTRY.register("PBE")
class PBELoss(BaseLoss):
    """
    Loss based on the Power Balance Equations.

    Adapted for the heterogeneous graph convention: predictions and targets
    are passed as dicts (``{"bus": …, "gen": …}``). Generator active power
    is aggregated onto bus nodes via the ``(gen, connected_to, bus)`` edge
    index before computing the power balance.
    """

    def __init__(self, loss_args, args):
        super(PBELoss, self).__init__()
        # When enabled, per-node residuals are added to the result dict.
        self.visualization = getattr(loss_args, "visualization", False)

    def forward(
        self,
        pred_dict,
        target_dict,
        edge_index_dict,
        edge_attr_dict,
        mask_dict,
        model=None,
    ):
        """Return the mean |S_net - S_injection| power-balance residual.

        Args:
            pred_dict: {"bus": [N_bus, output_bus_dim], ...} predictions.
            target_dict: ground-truth node features per node type.
            edge_index_dict: hetero edge indices; uses the bus-bus and
                gen-bus relations.
            edge_attr_dict: bus-bus edge attrs holding branch pi-model
                admittance columns (YFF_TT_*, YFT_TF_*).
            mask_dict: boolean masks per node type; True = model must
                predict, False = value is known and clamped to target.
            model: unused (kept for the common loss signature).
        """
        pred_bus = pred_dict["bus"]  # [N_bus, output_bus_dim]
        target_bus = target_dict["bus"]  # [N_bus, bus_feat_dim]
        num_bus = target_bus.size(0)

        bus_edge_index = edge_index_dict[("bus", "connects", "bus")]
        bus_edge_attr = edge_attr_dict[("bus", "connects", "bus")]
        mask_bus = mask_dict["bus"]

        # --- Clamp known values to ground truth ---
        # In power flow, certain variables are "known" (unmasked) at each
        # bus type (e.g. VM at PV buses, VA at REF). The model only needs
        # to predict *masked* unknowns; for everything else we substitute
        # the ground truth so that errors in non-target outputs do not
        # pollute the physics loss. This matches the reference's
        # ``temp_pred[unmasked] = target[unmasked]`` convention.

        Vm_pred = pred_bus[:, VM_OUT]
        Va_pred = pred_bus[:, VA_OUT]
        Vm_target = target_bus[:, VM_H]
        Va_target = target_bus[:, VA_H]

        mask_Vm = mask_bus[:, VM_H]
        mask_Va = mask_bus[:, VA_H]

        V_m = torch.where(mask_Vm, Vm_pred, Vm_target)
        V_a = torch.where(mask_Va, Va_pred, Va_target)

        # Complex voltage
        V = V_m * torch.exp(1j * V_a)
        V_conj = torch.conj(V)

        # --- Admittance matrix from bus-bus edge attrs ---
        # Off-diagonal: Y[from][to] = Yft, Y[to][from] = Ytf, stored in the
        # YFT_TF columns of the edge attributes.
        #
        # Diagonal: Y[k][k] = sum of Yff/Ytt for all branches at bus k.
        # The dataset stores Yff (forward edges) and Ytt (reverse edges) in
        # the YFF_TT columns. For every edge, YFF_TT at the *source* bus
        # gives that branch's diagonal contribution at the source. Summing
        # over all edges with source == k yields the full branch-diagonal.
        #
        # The reference project loads a pre-built Y-bus (y_bus_data.parquet)
        # that includes self-loops for diagonal entries. Here we reconstruct
        # the same structure from per-branch pi-model parameters.

        # Off-diagonal admittance values
        edge_offdiag = bus_edge_attr[:, YFT_TF_R] + 1j * bus_edge_attr[:, YFT_TF_I]

        # Diagonal: aggregate Yff/Ytt to source bus of each edge
        Y_diag_r = scatter_add(
            bus_edge_attr[:, YFF_TT_R],
            bus_edge_index[0],
            dim=0,
            dim_size=num_bus,
        )
        Y_diag_i = scatter_add(
            bus_edge_attr[:, YFF_TT_I],
            bus_edge_index[0],
            dim=0,
            dim_size=num_bus,
        )
        Y_diag = Y_diag_r + 1j * Y_diag_i

        # Build complete Y-bus: off-diagonal edges + self-loops for diagonal
        diag_idx = torch.arange(num_bus, device=bus_edge_index.device)
        full_edge_index = torch.cat(
            [bus_edge_index, torch.stack([diag_idx, diag_idx])],
            dim=1,
        )
        full_edge_values = torch.cat([edge_offdiag, Y_diag])

        # Conjugate the *values* before building the sparse tensor:
        # torch.conj on a sparse COO tensor only sets a lazy conj bit,
        # which sparse matmul kernels do not reliably honour.
        Y_bus_conj = to_torch_coo_tensor(
            full_edge_index,
            torch.conj(full_edge_values),
            size=(num_bus, num_bus),
        )

        # Complex power injection: S_inj = diag(V) * conj(Y) * conj(V).
        # diag(V) @ x is just the elementwise product V * x, so we avoid
        # materialising a dense N x N diagonal matrix (and the unsupported
        # dense @ sparse matmul it would require).
        S_injection = V * (Y_bus_conj @ V_conj.unsqueeze(-1)).squeeze(-1)

        # --- Net power from predictions/targets ---
        # Pg: use bus head prediction where masked, ground truth where known.
        # Ground truth is aggregated from generator targets onto buses.
        gen_to_bus_ei = edge_index_dict[("gen", "connected_to", "bus")]
        target_pg_agg = scatter_add(
            target_dict["gen"][:, PG_H],
            gen_to_bus_ei[1],
            dim=0,
            dim_size=num_bus,
        )
        gen_pg_masked = mask_dict["gen"][:, PG_H].float()
        any_gen_masked = (
            scatter_add(
                gen_pg_masked,
                gen_to_bus_ei[1],
                dim=0,
                dim_size=num_bus,
            )
            > 0
        )
        Pg_per_bus = torch.where(any_gen_masked, pred_bus[:, PG_OUT], target_pg_agg)

        # Pd, Qd, Qg: same clamp-to-ground-truth logic. The size guard
        # (``pred_bus.size(1) > *_OUT``) handles models with a narrow bus
        # head (e.g. output_bus_dim=4) that don't predict PD/QD/QG; in that
        # case the target is always used.
        if pred_bus.size(1) > PD_OUT:
            Pd = torch.where(mask_bus[:, PD_H], pred_bus[:, PD_OUT], target_bus[:, PD_H])
        else:
            Pd = target_bus[:, PD_H]
        if pred_bus.size(1) > QD_OUT:
            Qd = torch.where(mask_bus[:, QD_H], pred_bus[:, QD_OUT], target_bus[:, QD_H])
        else:
            Qd = target_bus[:, QD_H]
        if pred_bus.size(1) > QG_OUT:
            Qg = torch.where(mask_bus[:, QG_H], pred_bus[:, QG_OUT], target_bus[:, QG_H])
        else:
            Qg = target_bus[:, QG_H]

        net_P = Pg_per_bus - Pd
        net_Q = Qg - Qd
        S_net = net_P + 1j * net_Q

        # --- Loss ---
        residual = S_net - S_injection
        loss = torch.mean(torch.abs(residual))

        real_loss = torch.mean(torch.abs(torch.real(residual)))
        imag_loss = torch.mean(torch.abs(torch.imag(residual)))

        result = {
            "loss": loss,
            "Power loss in p.u.": loss.detach(),
            "Active Power Loss in p.u.": real_loss.detach(),
            "Reactive Power Loss in p.u.": imag_loss.detach(),
        }
        if self.visualization:
            result["Nodal Active Power Loss in p.u."] = torch.abs(torch.real(residual))
            result["Nodal Reactive Power Loss in p.u."] = torch.abs(torch.imag(residual))
        return result
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
A unified script for benchmarking and limited custom profiling. Benchmarking
columns in the output csv are [batch_size,avg_time_per_sample_ms].

Supports two model types via --model flag:
  - "hetero" (default): GNS_heterogeneous with HeteroData (bus + gen nodes)
  - "grit": GritHeteroAdapter with HeteroData (bus + gen nodes, optional PE attrs)

Example usage — Heterogeneous GNS (num_edges is the directed edge count,
i.e. 2 x the branch count):

######################################

CONF_PATH=../examples/config
OUT_DIR=../scripts
mkdir -p $OUT_DIR

python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case30.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/case30.csv || true
python benchmark_model_inference.py --model hetero --config $CONF_PATH/HGNS_PF_datakit_case118.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/case118.csv || true

######################################

Example usage — GRIT (HeteroData with PE, --num_gens required):

######################################

python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 30 --num_edges 82 --num_gens 6 --iterations 20 --output_csv $OUT_DIR/grit_case30.csv || true
python benchmark_model_inference.py --model grit --config $CONF_PATH/GRIT_PF_datakit_case14.yaml --num_nodes 118 --num_edges 372 --num_gens 54 --iterations 20 --output_csv $OUT_DIR/grit_case118.csv || true

######################################

Author(s): Mangaliso M. - mngomezulum@ibm.com
           Matteo M. - Not Available
"""

import os
import time
import csv
import yaml
import torch
import argparse
import platform
from datetime import datetime
from torch_geometric.loader import DataLoader
from torch_geometric.data import HeteroData
from gridfm_graphkit.io.param_handler import NestedNamespace, load_model

# Optional: tqdm (imported but not required for core flow)
try:
    from tqdm import tqdm  # noqa: F401
except Exception:
    pass

# Compilation (kept from original)
import torch._dynamo as dynamo
dynamo.config.suppress_errors = False

# ----------------------------
# Argument Parsing
# ----------------------------
parser = argparse.ArgumentParser(description="Benchmark GNN Model inference with profiling CSV")
parser.add_argument("--model", type=str, choices=["hetero", "grit"], default="hetero",
                    help="Model type: 'hetero' for GNS_heterogeneous, 'grit' for GritTransformer")
parser.add_argument("--config", type=str, required=True, help="Path to config YAML for model")
parser.add_argument("--num_nodes", type=int, required=True)
parser.add_argument("--num_gens", type=int, default=0,
                    help="Number of generator nodes (required for hetero, ignored for grit)")
parser.add_argument("--num_edges", type=int, required=True)
parser.add_argument("--output_csv", type=str, required=True)
parser.add_argument("--iterations", type=int, default=20)
parser.add_argument("--num_workers", type=int, default=0, help="DataLoader num_workers")
parser.add_argument("--pin_memory", action="store_true", help="Enable pin_memory in DataLoader when CUDA is available")
args = parser.parse_args()

# --- Custom logging (ensure directory exists)
import logging
os.makedirs('logs', exist_ok=True)
logger = logging.getLogger('ibm_benchmark_logger')
logger.setLevel(logging.DEBUG)
logger.propagate = False
if not logger.handlers:
    # Create the handler only when it will actually be attached; building it
    # unconditionally would leak an open file descriptor on every re-import.
    file_handler = logging.FileHandler('logs/ibm_bench_logs.log', mode='a')  # 'a' for append, 'w' to overwrite
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

# ----------------------------
# Load Model
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(args.config, "r") as f:
    base_config = yaml.safe_load(f)

config_args = NestedNamespace(**base_config)
model = load_model(config_args).to(device).eval()
tot_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("**Total model trainable params: {}".format(tot_params))

# ----------------------------
# Parameters
# ----------------------------
MODEL_TYPE = args.model
N_BUS = args.num_nodes
E = args.num_edges

# Default num_gens when not provided (shell script omits --num_gens)
N_GEN = args.num_gens if args.num_gens > 0 else max(1, N_BUS // 5)

EDGE_FEATS = getattr(config_args.model, "edge_dim", 10)

# Both model types use HeteroData with bus + gen nodes.
# Fall back to input_dim / defaults for configs that lack the hetero keys.
BUS_FEATS = getattr(config_args.model, "input_bus_dim",
                    getattr(config_args.model, "input_dim", 15))
GEN_FEATS = getattr(config_args.model, "input_gen_dim", 6)

if MODEL_TYPE == "grit":
    # Positional encoding config (only GRIT uses these)
    # Read enablement and dimensions from data config (canonical source).
    RRWP_ENABLED = getattr(config_args.data.posenc_RRWP, "enable", False) if hasattr(config_args.data, "posenc_RRWP") else False
    RRWP_KSTEPS = getattr(config_args.data.posenc_RRWP, "ksteps", 21) if RRWP_ENABLED else 0
    RWSE_ENABLED = hasattr(config_args.data, "posenc_RWSE") and getattr(config_args.data.posenc_RWSE, "enable", False)
    RWSE_TIMES = getattr(config_args.data.posenc_RWSE.kernel, "times", 21) if RWSE_ENABLED else 0
else:
    RRWP_ENABLED = False
    RRWP_KSTEPS = 0
    RWSE_ENABLED = False
    RWSE_TIMES = 0

# Keep original batch sizes list
batch_sizes = [1, 2, 4, 8, 16, 32]
iterations = args.iterations

# ----------------------------
# Helpers
# ----------------------------
def now_ms() -> float:
    """Monotonic wall-clock time in milliseconds."""
    return time.perf_counter() * 1000.0

def maybe_cuda_sync():
    """Synchronize CUDA so wall-clock timings bracket completed GPU work."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def get_env_info():
    """Collect runtime/hardware context (CPU/GPU names, torch/CUDA versions)."""
    # CPU name detection
    cpu_name = None
    try:
        cpu_name = platform.processor() or None
        if not cpu_name and os.path.exists("/proc/cpuinfo"):
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_name = line.strip().split(":", 1)[1].strip()
                        break
        if not cpu_name:
            cpu_name = platform.uname().machine
    except Exception:
        cpu_name = "unknown"

    # GPU names and device info
    if torch.cuda.is_available():
        try:
            gpu_names_list = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
            gpu_names = "; ".join(gpu_names_list)
        except Exception:
            gpu_names = "cuda_available_but_name_unreadable"
        device_type = "cuda"
        device_name = torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else "cuda"
        cuda_version_in_torch = torch.version.cuda
        cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None
    else:
        # Apple Metal backend?
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_type = "mps"
            device_name = "Apple MPS"
            gpu_names = "mps"
            cuda_version_in_torch = None
            cudnn_version = None
        else:
            device_type = "cpu"
            device_name = "cpu"
            gpu_names = "none"
            cuda_version_in_torch = None
            cudnn_version = None

    info = {
        "device_type": device_type,
        "device_name": device_name,
        "gpu_names": gpu_names,
        "cpu_name": cpu_name,
        "torch_version": torch.__version__,
        "cuda_version_in_torch": cuda_version_in_torch,
        "cudnn_version": cudnn_version,
        "python_version": platform.python_version(),
    }
    return info

# ----------------------------
# Generate Synthetic Hetero Graph
# ----------------------------
def generate_hetero_graph():
    """
    Generates a dummy heterogeneous power network graph for benchmarking.

    Returns:
        data (HeteroData): single self-contained heterogeneous graph with:
            - data["bus"].x, data["gen"].x
            - edge_index & edge_attr for all relations
            - mask_dict inside data.mask_dict
    """
    data = HeteroData()

    # Node features
    data["bus"].x = torch.randn(N_BUS, BUS_FEATS)
    data["gen"].x = torch.randn(N_GEN, GEN_FEATS)

    # Dummy targets
    data["bus"].y = torch.randn(N_BUS, BUS_FEATS)
    data["gen"].y = torch.randn(N_GEN, GEN_FEATS)

    # Edges: Bus–Bus
    src = torch.randint(0, N_BUS, (E,))
    dst = torch.randint(0, N_BUS, (E,))
    data["bus", "connects", "bus"].edge_index = torch.stack([src, dst], dim=0)
    data["bus", "connects", "bus"].edge_attr = torch.randn(E, EDGE_FEATS)

    # Edges: Gen–Bus & Bus–Gen
    gen_to_bus = torch.randint(0, N_BUS, (N_GEN,))

    # Gen → Bus
    data["gen", "connected_to", "bus"].edge_index = torch.stack(
        [torch.arange(N_GEN), gen_to_bus], dim=0
    )

    # Bus → Gen
    data["bus", "connected_to", "gen"].edge_index = torch.stack(
        [gen_to_bus, torch.arange(N_GEN)], dim=0
    )

    # No edge features for these
    data["gen", "connected_to", "bus"].edge_attr = None
    data["bus", "connected_to", "gen"].edge_attr = None

    # Dummy masks (all True)
    mask_bus = torch.ones_like(data["bus"].x, dtype=torch.bool)
    mask_gen = torch.ones_like(data["gen"].x, dtype=torch.bool)
    bus_types = torch.randint(0, 3, (N_BUS,))
    mask_branch = torch.ones_like(data["bus", "connects", "bus"].edge_attr, dtype=torch.bool)

    mask_PQ = bus_types == 0
    mask_PV = bus_types == 1
    mask_REF = bus_types == 2

    data.mask_dict = {
        "bus": mask_bus,
        "gen": mask_gen,
        "PQ": mask_PQ,
        "PV": mask_PV,
        "REF": mask_REF,
        "branch": mask_branch
    }
    return data


# ----------------------------
# Generate Synthetic Hetero Graph with PE (GRIT)
# ----------------------------
def generate_grit_graph():
    """
    Generates a dummy heterogeneous graph for GRIT benchmarking.

    GritHeteroAdapter expects a HeteroData batch, so we generate the same
    structure as generate_hetero_graph() but also attach PE attributes on
    the bus node store when RWSE / RRWP are enabled.

    Returns:
        data (HeteroData): heterogeneous graph with bus & gen nodes,
            plus optional PE attributes on data["bus"].
    """
    data = generate_hetero_graph()

    # RWSE positional encoding on bus nodes
    if RWSE_ENABLED:
        data["bus"].pestat_RWSE = torch.randn(N_BUS, RWSE_TIMES).abs()

    # RRWP positional / structural encoding on bus nodes
    if RRWP_ENABLED:
        data["bus"].rrwp = torch.randn(N_BUS, RRWP_KSTEPS)
        # Sparse RRWP for edges: include existing bus-bus edges + self-loops
        bb_ei = data["bus", "connects", "bus"].edge_index
        self_loops = torch.arange(N_BUS).unsqueeze(0).repeat(2, 1)
        rrwp_idx = torch.cat([bb_ei, self_loops], dim=1)
        rrwp_nnz = rrwp_idx.size(1)
        data["bus"].rrwp_index = rrwp_idx
        data["bus"].rrwp_val = torch.randn(rrwp_nnz, RRWP_KSTEPS)

    return data

# ----------------------------
# Benchmark Function
# ----------------------------
def benchmark():
    """Run inference timing over `batch_sizes` and append rows to the CSV.

    Returns:
        (batch_sizes_used, times): batch sizes measured and the matching
        average per-sample times in milliseconds.
    """
    # Environment/context info (constant per run)
    env = get_env_info()
    timestamp = datetime.now().isoformat(timespec='seconds')

    # Measure synthetic graph creation
    t0 = now_ms()
    if MODEL_TYPE == "hetero":
        data = generate_hetero_graph()
    else:
        data = generate_grit_graph()
    t1 = now_ms()
    data_gen_time_ms = t1 - t0

    # Move the base graph to device (preserve original behavior)
    maybe_cuda_sync()
    t2 = now_ms()
    data = data.to(device)
    maybe_cuda_sync()
    t3 = now_ms()
    graph_to_device_time_ms = t3 - t2

    batch_sizes_used = []
    times = []

    header = [
        # Keep original first two columns
        "batch_size",
        "avg_time_per_sample_ms",

        # Execution config
        "num_iters",
        "total_samples",

        # Data/IO timing
        "data_gen_time_ms",
        "graph_to_device_time_ms",
        "clone_list_time_ms",
        "dataloader_create_time_ms",
        "dataloader_first_iter_time_ms",
        "batch_to_device_time_ms",

        # Model timing
        "warmup_time_ms",
        "iter_total_wall_time_ms",
        "iter_gpu_time_ms",
        "gpu_idle_time_ms",
        "gpu_busy_ratio",
        "samples_per_sec_wall",
        "samples_per_sec_gpu",
        "timing_source",  # "cuda_event" or "wall_clock"

        # Memory
        "max_cuda_mem_alloc_bytes",
        "max_cuda_mem_reserved_bytes",

        # Graph & model context
        "n_bus", "n_gen", "n_edges",
        "bus_feats", "gen_feats", "edge_feats",

        # Runtime context
        "device_type", "device_name",
        "torch_version", "cuda_version_in_torch", "cudnn_version",
        "python_version",
        "cpu_name",  # NEW
        "gpu_names",  # NEW
        "timestamp_iso",
        "num_workers",
        "pin_memory",
    ]

    with open(args.output_csv, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)

        for batch_size in batch_sizes:
            # Build list of graphs (on device, preserving original flow)
            maybe_cuda_sync()
            t_clone_start = now_ms()
            data_list = [data.clone() for _ in range(batch_size)]
            maybe_cuda_sync()
            t_clone_end = now_ms()
            clone_list_time_ms = t_clone_end - t_clone_start

            # Create DataLoader
            pin_mem = args.pin_memory and torch.cuda.is_available()
            persistent = args.num_workers > 0
            t_dl_create_start = now_ms()
            loader = DataLoader(
                data_list,
                batch_size=batch_size,
                num_workers=args.num_workers,
                pin_memory=pin_mem,
                persistent_workers=persistent,
            )
            t_dl_create_end = now_ms()
            dataloader_create_time_ms = t_dl_create_end - t_dl_create_start

            # Fetch first batch (collate)
            t_iter_start = now_ms()
            batch = next(iter(loader))
            t_iter_end = now_ms()
            dataloader_first_iter_time_ms = t_iter_end - t_iter_start

            # Ensure batch on device (likely ~0 if items already on device)
            maybe_cuda_sync()
            t_b2d_start = now_ms()
            batch = batch.to(device, non_blocking=True) if torch.cuda.is_available() else batch.to(device)
            maybe_cuda_sync()
            t_b2d_end = now_ms()
            batch_to_device_time_ms = t_b2d_end - t_b2d_start

            test_model = model

            # Warmup (excluded from main timing)
            maybe_cuda_sync()
            t_warmup_start = now_ms()
            with torch.no_grad():
                for _ in range(5):
                    _ = test_model(batch.clone())
            maybe_cuda_sync()
            t_warmup_end = now_ms()
            warmup_time_ms = t_warmup_end - t_warmup_start

            num_iters = iterations
            total_samples = batch_size * num_iters

            # Reset CUDA memory stats and set up timing
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats(device)
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

            # Iteration timing
            maybe_cuda_sync()
            wall_start = now_ms()
            with torch.no_grad():
                if torch.cuda.is_available():
                    start_event.record()
                for _ in range(num_iters):
                    _ = test_model(batch.clone())
                if torch.cuda.is_available():
                    end_event.record()
            maybe_cuda_sync()
            wall_end = now_ms()

            iter_total_wall_time_ms = wall_end - wall_start

            if torch.cuda.is_available():
                iter_gpu_time_ms = float(start_event.elapsed_time(end_event))  # ms
                timing_source = "cuda_event"
                avg_time_per_sample_ms = iter_gpu_time_ms / total_samples
                gpu_idle_time_ms = max(iter_total_wall_time_ms - iter_gpu_time_ms, 0.0)
                gpu_busy_ratio = (iter_gpu_time_ms / iter_total_wall_time_ms) if iter_total_wall_time_ms > 0 else None
                max_cuda_mem_alloc_bytes = int(torch.cuda.max_memory_allocated(device))
                max_cuda_mem_reserved_bytes = int(torch.cuda.max_memory_reserved(device))
                samples_per_sec_gpu = (total_samples / (iter_gpu_time_ms / 1000.0)) if iter_gpu_time_ms > 0 else None
            else:
                iter_gpu_time_ms = None
                timing_source = "wall_clock"
                avg_time_per_sample_ms = iter_total_wall_time_ms / total_samples
                gpu_idle_time_ms = None
                gpu_busy_ratio = None
                max_cuda_mem_alloc_bytes = None
                max_cuda_mem_reserved_bytes = None
                samples_per_sec_gpu = None

            samples_per_sec_wall = (total_samples / (iter_total_wall_time_ms / 1000.0)) if iter_total_wall_time_ms > 0 else None

            # Prepare row
            row = [
                batch_size,
                avg_time_per_sample_ms,

                num_iters,
                total_samples,

                data_gen_time_ms,
                graph_to_device_time_ms,
                clone_list_time_ms,
                dataloader_create_time_ms,
                dataloader_first_iter_time_ms,
                batch_to_device_time_ms,

                warmup_time_ms,
                iter_total_wall_time_ms,
                iter_gpu_time_ms,
                gpu_idle_time_ms,
                gpu_busy_ratio,
                samples_per_sec_wall,
                samples_per_sec_gpu,
                timing_source,

                max_cuda_mem_alloc_bytes,
                max_cuda_mem_reserved_bytes,

                N_BUS, N_GEN, E,
                BUS_FEATS, GEN_FEATS, EDGE_FEATS,

                env["device_type"], env["device_name"],
                env["torch_version"], env["cuda_version_in_torch"], env["cudnn_version"],
                env["python_version"],
                env["cpu_name"],
                env["gpu_names"],
                timestamp,
                args.num_workers,
                bool(pin_mem),
            ]

            writer.writerow(row)
            csvfile.flush()
            batch_sizes_used.append(batch_size)
            times.append(avg_time_per_sample_ms)

    return batch_sizes_used, times


if __name__ == "__main__":
    print(f"Starting benchmark for {os.path.basename(args.output_csv)} ..")
    benchmark()
    print(f"Finished benchmarking for {os.path.basename(args.output_csv)}\n ...")
#!/bin/bash

set +e  # Do NOT exit on error

CONFIGS=(
    "grit01"
)

CONFIG_PATHS=(
    "../examples/config/GRIT_PF_datakit_case14.yaml"
)

GRAPH_SIZES=(
    "30 110"
    "300 1120"
    "2000 9276"
    "3022 11390"
    "9241 41337"
    "30000 100784"
)

OUTPUT_DIR="benchmark_results"
# Quote all expansions so paths with spaces don't word-split.
mkdir -p "$OUTPUT_DIR"
for i in "${!CONFIGS[@]}"; do
    config_name="${CONFIGS[$i]}"
    config_path="${CONFIG_PATHS[$i]}"
    for size in "${GRAPH_SIZES[@]}"; do
        read -r nodes edges <<< "$size"
        output_file="${OUTPUT_DIR}/${config_name}_${nodes}nodes_${edges}edges.csv"
        echo "Running benchmark for $config_name with $nodes nodes and $edges edges..."
        # --num_gens is intentionally omitted: the python script falls back
        # to max(1, num_nodes // 5) when it is not provided.
        python benchmark_model_inference.py \
            --model "grit" \
            --config "$config_path" \
            --output_csv "$output_file" \
            --num_nodes "$nodes" \
            --num_edges "$edges" || echo "Failed for $config_name with $nodes nodes"
    done
done