diff --git a/gridfm_graphkit/datasets/normalizers.py b/gridfm_graphkit/datasets/normalizers.py index 11601a6..2276a2f 100644 --- a/gridfm_graphkit/datasets/normalizers.py +++ b/gridfm_graphkit/datasets/normalizers.py @@ -228,7 +228,7 @@ def transform(self, data: HeteroData): data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MIN] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, ANG_MAX] *= torch.pi / 180.0 data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA - data.baseMVA = self.baseMVA + data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32) data.is_normalized = True def inverse_transform(self, data: HeteroData):