fix(normalizers): cast baseMVA to float32 to support MPS backend #47
emmanuelbadmus wants to merge 1 commit into gridfm:main from
Conversation
Pull request overview
This PR aims to prevent PyTorch MPS (Apple Silicon GPU) crashes by avoiding float64 usage for the baseMVA scaling factor during normalization.
Changes:
- Updates `HeteroDataMVANormalizer.transform()` to store `data.baseMVA` as a `torch.float32` tensor instead of a Python scalar.
Comments suppressed due to low confidence (1)
gridfm_graphkit/datasets/normalizers.py:244
- Casting `data.baseMVA` to float32 here can make `inverse_transform()` fail the exact equality check `(data.baseMVA != self.baseMVA).any()` due to float32 rounding vs `self.baseMVA` (often a higher-precision Python/NumPy float), causing a false mismatch error. Either keep `data.baseMVA` in the same precision as `self.baseMVA`, or cast/compare using a tolerance. Also note that the MPS float64 crash is more directly caused by `self.baseMVA` being an `np.float64` from `np.percentile(...)` (see `fit()`); casting `self.baseMVA` to a Python float/float32 before the division ops would address dtype promotion at the source.
```python
        data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32)
        data.is_normalized = True

    def inverse_transform(self, data: HeteroData):
        if self.baseMVA is None or self.baseMVA == 0:
            raise ValueError("BaseMVA not properly set")
        if not data.is_normalized.all():
            raise ValueError("Attempting to denormalize data which is not normalized")
        if (data.baseMVA != self.baseMVA).any():
            raise ValueError(
                f"Normalizer baseMVA was {self.baseMVA} but Data object baseMVA is {data.baseMVA}",
            )
```
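The false-mismatch the reviewer describes is easy to reproduce outside the repo. A minimal sketch (variable names `fitted` and `stored` are illustrative, not from the codebase): a full-precision `np.float64` from `fit()` no longer compares equal to its float32-rounded copy, while a tolerance-based check still accepts it.

```python
import numpy as np
import torch

# self.baseMVA as fit() might produce it: a full-precision np.float64
fitted = np.float64(0.1) + np.float64(0.2)

# data.baseMVA as the PR stores it: rounded to float32
stored = torch.tensor(fitted, dtype=torch.float32)

# an exact comparison in float64 sees the rounding -> false mismatch
exact_match = stored.double().item() == float(fitted)

# a tolerance-based check treats the two values as equal
close_match = bool(torch.isclose(stored.double(), torch.tensor(fitted), rtol=1e-6))
```

Here `exact_match` is `False` even though both values came from the same fit, which is exactly the spurious `ValueError` path; `close_match` is `True`.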
```diff
  data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA
- data.baseMVA = self.baseMVA
+ data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32)
  data.is_normalized = True
```
`data.baseMVA` is changed from a Python scalar to a `torch.Tensor`. This breaks `LoadGridParamsFromPath.forward()` (gridfm_graphkit/datasets/transforms.py:114), which assigns `data.baseMVA` into `HeteroDataMVANormalizer.baseMVA`; `HeteroDataMVANormalizer.transform()` then does `if self.baseMVA is None or self.baseMVA == 0:` and will error when `self.baseMVA` is a tensor. Consider keeping `data.baseMVA` as a Python float (or ensuring callers use `.item()`), and instead casting the scalar used in normalization ops to a float32 tensor on the right device/dtype.
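One way to apply this suggestion is to store a plain Python float and cast only at the point of use. The sketch below is illustrative (the `Normalizer` class and `scale` method are stand-ins, not the repo's actual API): scalar checks like `self.baseMVA == 0` keep working, and the tensor op picks up the input's device and dtype.

```python
import numpy as np
import torch

class Normalizer:
    """Minimal stand-in for HeteroDataMVANormalizer; illustration only."""

    def __init__(self, baseMVA):
        # normalize to a plain Python float up front, so scalar checks
        # like `self.baseMVA == 0` and equality with other floats still work
        self.baseMVA = float(baseMVA)

    def scale(self, x: torch.Tensor) -> torch.Tensor:
        # cast the scalar at the point of use, matching the tensor's
        # device and dtype instead of forcing float64 promotion
        factor = torch.tensor(self.baseMVA, dtype=x.dtype, device=x.device)
        return x / factor

norm = Normalizer(np.float64(100.0))  # e.g. an np.percentile result
out = norm.scale(torch.ones(3, dtype=torch.float32))
```

With this shape, `data.baseMVA` can stay a scalar everywhere, and only the division operand is a tensor.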
The PyTorch MPS (Apple Silicon GPU) backend does not support float64. This PR casts `baseMVA` to float32 in the normalizer to prevent crashes on Mac GPUs.
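The dtype promotion at the root of the crash can be reproduced without an MPS device (a sketch; the array values are arbitrary): `np.percentile` returns an `np.float64` scalar, and `torch.tensor` preserves that NumPy dtype, yielding a float64 tensor that MPS rejects. Converting to a Python float first, or passing an explicit dtype as this PR does, yields float32.

```python
import numpy as np
import torch

# np.percentile on float input returns an np.float64 scalar
p = np.percentile(np.array([1.0, 2.0, 3.0]), 90)

# torch.tensor preserves the NumPy dtype -> a float64 tensor, which
# would fail when moved to an MPS device
t_bad = torch.tensor(p)

# a Python float gives torch's default dtype (float32), as does an
# explicit dtype argument, which is the approach taken in this PR
t_ok = torch.tensor(float(p))
t_explicit = torch.tensor(p, dtype=torch.float32)
```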