
fix(normalizers): cast baseMVA to float32 to support MPS backend#47

Open
emmanuelbadmus wants to merge 1 commit into gridfm:main from emmanuelbadmus:fix/mps-float64-support

Conversation

@emmanuelbadmus

The MPS (Apple Silicon) backend for PyTorch does not support float64. This PR ensures baseMVA is cast to float32 in the normalizer to prevent crashes on Mac GPUs.
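A minimal sketch of the limitation this PR works around, assuming a machine where the MPS backend may or may not be available. The float32 tensor is accepted on every backend, while allocating float64 on MPS is expected to raise (the exact exception type on MPS is an assumption here; on CPU no error occurs):

```python
import torch

# Pick the MPS device when available, otherwise fall back to CPU so the
# snippet runs everywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# float32 works on MPS and CPU alike -- this is what the PR casts baseMVA to.
ok = torch.tensor(100.0, dtype=torch.float32, device=device)

if device == "mps":
    try:
        # On MPS this allocation is rejected because the backend has no
        # float64 support (assumed to surface as a TypeError).
        torch.tensor(100.0, dtype=torch.float64, device=device)
    except TypeError:
        pass  # float64 rejected; float32 is the safe choice on Apple GPUs
```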

Copilot AI review requested due to automatic review settings April 1, 2026 20:40

Copilot AI left a comment


Pull request overview

This PR aims to prevent PyTorch MPS (Apple Silicon GPU) crashes by avoiding float64 usage for the baseMVA scaling factor during normalization.

Changes:

  • Updates HeteroDataMVANormalizer.transform() to store data.baseMVA as a torch.float32 tensor instead of a Python scalar.

Comments suppressed due to low confidence (1)

gridfm_graphkit/datasets/normalizers.py:244

  • Casting data.baseMVA to float32 here can make inverse_transform() fail the exact equality check (data.baseMVA != self.baseMVA).any() due to float32 rounding vs self.baseMVA (often a higher-precision Python/NumPy float), causing a false mismatch error. Either keep data.baseMVA in the same precision as self.baseMVA, or cast/compare using a tolerance. Also note that the MPS float64 crash is more directly caused by self.baseMVA being a np.float64 from np.percentile(...) (see fit()); casting self.baseMVA to a Python float/float32 before the division ops would address dtype promotion at the source.
        data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32)
        data.is_normalized = True

    def inverse_transform(self, data: HeteroData):
        if self.baseMVA is None or self.baseMVA == 0:
            raise ValueError("BaseMVA not properly set")

        if not data.is_normalized.all():
            raise ValueError("Attempting to denormalize data which is not normalized")

        if (data.baseMVA != self.baseMVA).any():
            raise ValueError(
                f"Normalizer baseMVA was {self.baseMVA} but Data object baseMVA is {data.baseMVA}",
            )
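The tolerance-based comparison the reviewer suggests can be sketched as follows. The values are illustrative, not taken from the repo; the point is that 100.1 is not exactly representable in float32, so the exact check flags a spurious mismatch while `torch.isclose` does not:

```python
import torch

# self.baseMVA is often a higher-precision scalar (e.g. np.float64 from
# np.percentile); data.baseMVA becomes a float32 tensor after the PR's cast.
self_baseMVA = 100.1  # illustrative value, not exactly representable in float32
data_baseMVA = torch.tensor(self_baseMVA, dtype=torch.float32)

# Exact equality after the float32 round-trip reports a (false) mismatch.
exact_mismatch = bool((data_baseMVA.double() != self_baseMVA).any())

# A tolerance-based check treats the rounded value as equal.
tol_mismatch = not bool(
    torch.isclose(
        data_baseMVA.double(),
        torch.tensor(self_baseMVA, dtype=torch.float64),
    ).all()
)
```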


Comment on lines 230 to 232
data.edge_attr_dict[("bus", "connects", "bus")][:, RATE_A] /= self.baseMVA
data.baseMVA = self.baseMVA
data.baseMVA = torch.tensor(self.baseMVA, dtype=torch.float32)
data.is_normalized = True

Copilot AI Apr 1, 2026


data.baseMVA is changed from a Python scalar to a torch.Tensor. This breaks LoadGridParamsFromPath.forward() (gridfm_graphkit/datasets/transforms.py:114), which assigns data.baseMVA into HeteroDataMVANormalizer.baseMVA; HeteroDataMVANormalizer.transform() then does if self.baseMVA is None or self.baseMVA == 0: and will error when self.baseMVA is a tensor. Consider keeping data.baseMVA as a Python float (or ensuring callers use .item()), and instead cast the scalar used in normalization ops to a float32 tensor on the right device/dtype.
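The alternative this comment describes can be sketched as below. The names (`edge_attr`, `RATE_A`, `base_mva`) are illustrative stand-ins for the repo's objects, not its actual API: `data.baseMVA` stays a plain Python float so scalar checks like `self.baseMVA == 0` keep working, and only the value entering tensor arithmetic is cast to the data's dtype and device:

```python
import torch

# Keep baseMVA as a Python float at the source, e.g.
# float(np.percentile(...)) in fit(); this avoids np.float64 promotion.
base_mva = float(100.1)

edge_attr = torch.ones(4, 3)  # stand-in for data.edge_attr_dict[...]
RATE_A = 0                    # stand-in column index

# Cast the divisor to the data's own dtype/device, so the MPS backend
# never sees a float64 operand.
scale = torch.tensor(base_mva, dtype=edge_attr.dtype, device=edge_attr.device)
edge_attr[:, RATE_A] /= scale

# data.baseMVA remains a scalar; downstream `baseMVA == 0` checks still work.
data_baseMVA = base_mva
```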

