diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 01b8b42..68ee525 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -49,6 +49,8 @@ jobs: run: | python -m pip install --upgrade pip wheel pip install -e ".[test]" + TORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])") + pip install torch-scatter -f "https://data.pyg.org/whl/torch-${TORCH_VERSION}+cpu.html" - name: Unit tests run: | @@ -67,11 +69,11 @@ jobs: python3 -m pip install --upgrade pip python3 -m pip install --upgrade "git+https://github.com/ibm/detect-secrets.git@master#egg=detect-secrets" python3 -m pip install boxsdk - + - name: Scan repository & write snapshot run: | mkdir -p security-outputs - + # Run detect-secrets while skipping binary files detect-secrets scan \ --exclude-files '.*\.ipynb$|.*\.(png|jpg|jpeg|gif|pdf|onnx|pt|pth|bin|zip)$' \ @@ -166,22 +168,30 @@ jobs: pip install -e .[dev,test] || pip install -e . - name: Run pip-audit uses: pypa/gh-action-pip-audit@v1.1.0 + with: + # CVE-2026-4539: pygments AdlLexer ReDoS, local-only attack vector, no fix released yet + ignore-vulns: CVE-2026-4539 trivy_repo: - name: Trivy (repo scan) - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Run Trivy filesystem scan - uses: aquasecurity/trivy-action@0.33.1 - with: - scan-type: 'fs' - scan-ref: '.' - format: 'sarif' - output: 'trivy-results.sarif' - severity: 'HIGH,CRITICAL' - ignore-unfixed: true - - name: Upload SARIF to Code Scanning - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: trivy-results.sarif + name: Trivy (repo scan) + runs-on: ubuntu-latest + permissions: + security-events: write + steps: + - uses: actions/checkout@v4 + + - name: Run Trivy vulnerability scanner in repo mode + # We use the official container-based action to avoid binary install issues + uses: aquasecurity/trivy-action@master + with: + scan-type: 'fs' + ignore-unfixed: true + format: 'sarif' + output: 'trivy-results.sarif' + severity: 'HIGH,CRITICAL' + + - name: Upload SARIF to Code Scanning + uses: github/codeql-action/upload-sarif@v3 + if: always() # Upload results even if vulnerabilities are found + with: + sarif_file: 'trivy-results.sarif' diff --git a/docs/datasets/data_modules.md b/docs/datasets/data_modules.md index bf47118..a5e4dff 100644 --- a/docs/datasets/data_modules.md +++ b/docs/datasets/data_modules.md @@ -1,3 +1,3 @@ -# LitGridDataModule +# LitGridHeteroDataModule -::: gridfm_graphkit.datasets.powergrid_datamodule.LitGridDataModule +::: gridfm_graphkit.datasets.hetero_powergrid_datamodule.LitGridHeteroDataModule diff --git a/docs/datasets/data_normalization.md b/docs/datasets/data_normalization.md index f41334d..1747fd1 100644 --- a/docs/datasets/data_normalization.md +++ b/docs/datasets/data_normalization.md @@ -3,12 +3,10 @@ Normalization improves neural network training by ensuring features are well-scaled, preventing issues like exploding gradients and slow convergence. In power grids, where variables like voltage and power span wide ranges, normalization is essential. -The `gridfm-graphkit` package offers four methods: +The `gridfm-graphkit` package offers normalization methods based on the per-unit (p.u.) system: -- [`Min-Max Normalization`](#minmaxnormalizer) -- [`Standardization (Z-score)`](#standardizer) -- [`Identity (no normalization)`](#identitynormalizer) -- [`BaseMVA Normalization`](#basemvanormalizer) +- [`BaseMVA Normalization`](#heterodatamvanormalizer) +- [`Per-Sample BaseMVA Normalization`](#heterodatapersamplemvanormalizer) Each of these strategies implements a unified interface and can be used interchangeably depending on the learning task and data characteristics. @@ -25,27 +23,15 @@ Each of these strategies implements a unified interface and can be used intercha --- -### `MinMaxNormalizer` +### `HeteroDataMVANormalizer` -::: gridfm_graphkit.datasets.normalizers.MinMaxNormalizer +::: gridfm_graphkit.datasets.normalizers.HeteroDataMVANormalizer --- -### `Standardizer` +### `HeteroDataPerSampleMVANormalizer` -::: gridfm_graphkit.datasets.normalizers.Standardizer - ---- - -### `BaseMVANormalizer` - -::: gridfm_graphkit.datasets.normalizers.BaseMVANormalizer - ---- - -### `IdentityNormalizer` - -::: gridfm_graphkit.datasets.normalizers.IdentityNormalizer +::: gridfm_graphkit.datasets.normalizers.HeteroDataPerSampleMVANormalizer --- @@ -54,13 +40,18 @@ Each of these strategies implements a unified interface and can be used intercha Example: ```python -from gridfm_graphkit.datasets.normalizers import MinMaxNormalizer -import torch +from gridfm_graphkit.datasets.normalizers import HeteroDataMVANormalizer +from torch_geometric.data import HeteroData + +# Create normalizer +normalizer = HeteroDataMVANormalizer(args) + +# Fit on training data +params = normalizer.fit(data_path, scenario_ids) -data = torch.randn(100, 5) # Example tensor +# Transform data +normalizer.transform(hetero_data) -normalizer = MinMaxNormalizer(node_data=True,args=None) -params = normalizer.fit(data) -normalized = normalizer.transform(data) -restored = normalizer.inverse_transform(normalized) +# Inverse transform to restore original scale +normalizer.inverse_transform(hetero_data) ``` diff --git a/docs/datasets/powergrid.md b/docs/datasets/powergrid.md index 45476ac..1f983a5 100644 --- a/docs/datasets/powergrid.md +++ b/docs/datasets/powergrid.md @@ -1,3 +1,3 @@ -## `GridDatasetDisk` +## `HeteroGridDatasetDisk` -::: gridfm_graphkit.datasets.powergrid_dataset.GridDatasetDisk +::: gridfm_graphkit.datasets.powergrid_hetero_dataset.HeteroGridDatasetDisk diff --git a/docs/datasets/transforms.md b/docs/datasets/transforms.md index dd7f66d..0dcf981 100644 --- a/docs/datasets/transforms.md +++ b/docs/datasets/transforms.md @@ -2,26 +2,18 @@ > Each transformation class inherits from [`BaseTransform`](https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.BaseTransform) provided by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/). -### `AddNormalizedRandomWalkPE` +### `RemoveInactiveGenerators` -::: gridfm_graphkit.datasets.transforms.AddNormalizedRandomWalkPE +::: gridfm_graphkit.datasets.transforms.RemoveInactiveGenerators -### `AddEdgeWeights` +### `RemoveInactiveBranches` -::: gridfm_graphkit.datasets.transforms.AddEdgeWeights +::: gridfm_graphkit.datasets.transforms.RemoveInactiveBranches -### `AddIdentityMask` +### `ApplyMasking` -::: gridfm_graphkit.datasets.transforms.AddIdentityMask +::: gridfm_graphkit.datasets.transforms.ApplyMasking -### `AddRandomMask` +### `LoadGridParamsFromPath` -::: gridfm_graphkit.datasets.transforms.AddRandomMask - -### `AddPFMask` - -::: gridfm_graphkit.datasets.transforms.AddPFMask - -### `AddOPFMask` - -::: gridfm_graphkit.datasets.transforms.AddOPFMask +::: gridfm_graphkit.datasets.transforms.LoadGridParamsFromPath diff --git a/docs/install/installation.md b/docs/install/installation.md index 07dc502..89ee4be 100644 --- a/docs/install/installation.md +++ b/docs/install/installation.md @@ -1,14 +1,18 @@ +# Installation + You can install `gridfm-graphkit` directly from PyPI: ```bash pip install gridfm-graphkit ``` +For GPU support and compatibility with PyTorch Geometric's scatter operations, install PyTorch (and optionally CUDA) first, then install the matching `torch-scatter` wheel. See [PyTorch and torch-scatter](#pytorch-and-torch-scatter-optional) below. + --- ## Development Setup -To contribute or develop locally, clone the repository and install in editable mode: +To contribute or develop locally, clone the repository and install in editable mode. Use Python 3.10, 3.11, or 3.12 (3.12 is recommended). ```bash git clone git@github.com:gridfm/gridfm-graphkit.git @@ -18,6 +22,26 @@ source venv/bin/activate pip install -e . ``` +### PyTorch and torch-scatter (optional) + +If you need GPU acceleration or PyTorch Geometric scatter ops (used by the library), install PyTorch and the matching `torch-scatter` wheel: + +1. Install PyTorch (see [pytorch.org](https://pytorch.org/) for your platform and CUDA version). + +2. Get your Torch + CUDA version string: + ```bash + TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))") + ``` + +3. Install the correct `torch-scatter` wheel: + ```bash + pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html + ``` + +--- + +## Optional extras + For documentation generation and unit testing, install with the optional `dev` and `test` extras: ```bash diff --git a/docs/models/models.md b/docs/models/models.md index 9e822ca..7c8c5c6 100644 --- a/docs/models/models.md +++ b/docs/models/models.md @@ -1,10 +1,37 @@ # Models -### `GPSTransformer` +### `GNS_heterogeneous` -::: gridfm_graphkit.models.gps_transformer.GPSTransformer +::: gridfm_graphkit.models.gnn_heterogeneous_gns.GNS_heterogeneous +--- -### `GNN_TransformerConv` +## Physics Decoders -::: gridfm_graphkit.models.gnn_transformer.GNN_TransformerConv +### `PhysicsDecoderOPF` + +::: gridfm_graphkit.models.utils.PhysicsDecoderOPF + +### `PhysicsDecoderPF` + +::: gridfm_graphkit.models.utils.PhysicsDecoderPF + +### `PhysicsDecoderSE` + +::: gridfm_graphkit.models.utils.PhysicsDecoderSE + +--- + +## Utility Modules + +### `ComputeBranchFlow` + +::: gridfm_graphkit.models.utils.ComputeBranchFlow + +### `ComputeNodeInjection` + +::: gridfm_graphkit.models.utils.ComputeNodeInjection + +### `ComputeNodeResiduals` + +::: gridfm_graphkit.models.utils.ComputeNodeResiduals diff --git a/docs/tasks/base_task.md b/docs/tasks/base_task.md new file mode 100644 index 0000000..1153a2d --- /dev/null +++ b/docs/tasks/base_task.md @@ -0,0 +1,216 @@ +# Base Task + +The `BaseTask` class is an abstract base class that provides the foundation for all task implementations in GridFM-GraphKit. It extends PyTorch Lightning's `LightningModule` and defines the common interface and shared functionality for training, validation, and testing. + +## Overview + +`BaseTask` serves as the parent class for all task-specific implementations, providing: + +- **Abstract method definitions**: Enforces implementation of core methods in subclasses +- **Optimizer configuration**: Sets up AdamW optimizer with learning rate scheduling +- **Normalization statistics logging**: Saves normalization parameters for reproducibility +- **Hyperparameter management**: Automatically saves hyperparameters for experiment tracking + +## BaseTask Class + +::: gridfm_graphkit.tasks.base_task.BaseTask + options: + show_root_heading: true + show_source: true + members: + - __init__ + - forward + - training_step + - validation_step + - test_step + - predict_step + - on_fit_start + - configure_optimizers + +## Methods + +### `__init__(args, data_normalizers)` + +Initialize the base task with configuration and normalizers. + +**Parameters:** + +- `args` (NestedNamespace): Experiment configuration containing all hyperparameters +- `data_normalizers` (list): List of normalizer objects, one per dataset + +**Attributes Set:** + +- `self.args`: Stores the configuration +- `self.data_normalizers`: Stores the normalizers +- Automatically calls `save_hyperparameters()` for experiment tracking + +--- + +### `forward(*args, **kwargs)` (Abstract) + +Defines the forward pass through the model. Must be implemented by subclasses. + +**Returns:** + +- Model output (structure depends on task implementation) + +--- + +### `training_step(batch)` (Abstract) + +Executes one training step. Must be implemented by subclasses. + +**Parameters:** + +- `batch`: A batch of data from the training dataloader + +**Returns:** + +- Loss tensor for backpropagation + +--- + +### `validation_step(batch, batch_idx)` (Abstract) + +Executes one validation step. Must be implemented by subclasses. + +**Parameters:** + +- `batch`: A batch of data from the validation dataloader +- `batch_idx` (int): Index of the current batch + +**Returns:** + +- Loss tensor or metrics dictionary + +--- + +### `test_step(batch, batch_idx, dataloader_idx=0)` (Abstract) + +Executes one test step. Must be implemented by subclasses. + +**Parameters:** + +- `batch`: A batch of data from the test dataloader +- `batch_idx` (int): Index of the current batch +- `dataloader_idx` (int): Index of the dataloader (for multiple test datasets) + +**Returns:** + +- Metrics dictionary or None + +--- + +### `predict_step(batch, batch_idx, dataloader_idx=0)` (Abstract) + +Executes one prediction step. Must be implemented by subclasses. + +**Parameters:** + +- `batch`: A batch of data from the prediction dataloader +- `batch_idx` (int): Index of the current batch +- `dataloader_idx` (int): Index of the dataloader + +**Returns:** + +- Predictions dictionary + +--- + +### `on_fit_start()` + +Called at the beginning of training. Saves normalization statistics to disk. + +**Behavior:** + +- Creates a `stats` directory in the logging directory +- Saves human-readable normalization statistics to `normalization_stats.txt` +- Saves machine-loadable statistics to `normalizer_stats.pt` (PyTorch format) +- Only executes on rank 0 in distributed training (via `@rank_zero_only` decorator) + +**Output Files:** + +1. **`normalization_stats.txt`**: Human-readable text file with statistics for each dataset +2. **`normalizer_stats.pt`**: PyTorch file containing a dictionary keyed by network name + +--- + +### `configure_optimizers()` + +Configures the optimizer and learning rate scheduler. + +**Optimizer:** + +- **Type**: AdamW +- **Learning Rate**: From `args.optimizer.learning_rate` +- **Betas**: From `args.optimizer.beta1` and `args.optimizer.beta2` + +**Scheduler:** + +- **Type**: ReduceLROnPlateau +- **Mode**: Minimize +- **Factor**: From `args.optimizer.lr_decay` +- **Patience**: From `args.optimizer.lr_patience` +- **Monitored Metric**: "Validation loss" + +**Returns:** + +- Dictionary with optimizer and lr_scheduler configuration + +## Usage + +`BaseTask` is not used directly. Instead, create a subclass that implements all abstract methods: + +```python +from gridfm_graphkit.tasks.base_task import BaseTask + +class MyCustomTask(BaseTask): + def __init__(self, args, data_normalizers): + super().__init__(args, data_normalizers) + # Initialize task-specific components + + def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + # Implement forward pass + pass + + def training_step(self, batch): + # Implement training logic + pass + + def validation_step(self, batch, batch_idx): + # Implement validation logic + pass + + def test_step(self, batch, batch_idx, dataloader_idx=0): + # Implement test logic + pass + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + # Implement prediction logic + pass +``` + +## Configuration Example + +The base task uses the following configuration sections: + +```yaml +optimizer: + learning_rate: 0.001 + beta1: 0.9 + beta2: 0.999 + lr_decay: 0.7 + lr_patience: 5 + +data: + networks: + - case14_ieee + - case118_ieee +``` + +## Related + +- [Reconstruction Task](reconstruction_task.md): Base class for reconstruction tasks +- [Power Flow Task](power_flow.md): Concrete implementation for power flow +- [Optimal Power Flow Task](optimal_power_flow.md): Concrete implementation for OPF +- [State Estimation Task](state_estimation.md): Concrete implementation for state estimation diff --git a/docs/tasks/feature_reconstruction.md b/docs/tasks/feature_reconstruction.md index 39e0823..fbde3ea 100644 --- a/docs/tasks/feature_reconstruction.md +++ b/docs/tasks/feature_reconstruction.md @@ -1,3 +1,185 @@ -# Feature Reconstruction Task +# Task Classes Overview -::: gridfm_graphkit.tasks.feature_reconstruction_task.FeatureReconstructionTask +GridFM-GraphKit provides a hierarchical task system for power grid analysis. All tasks inherit from a common base class and share core functionality while implementing domain-specific logic. + +## Task Hierarchy + +``` +BaseTask (Abstract) + └── ReconstructionTask + ├── PowerFlowTask + ├── OptimalPowerFlowTask + └── StateEstimationTask +``` + +## Available Task Classes + +### Base Classes + +- **[BaseTask](base_task.md)**: Abstract base class providing common functionality for all tasks + - Optimizer configuration + - Learning rate scheduling + - Normalization statistics logging + - Abstract method definitions + +- **[ReconstructionTask](reconstruction_task.md)**: Base class for feature reconstruction tasks + - Model integration + - Loss function handling + - Shared training/validation logic + - Test output management + +### Concrete Task Implementations + +- **[PowerFlowTask](power_flow.md)**: Power flow analysis + - Computes voltage profiles and power flows + - Physics-based validation with Power Balance Error (PBE) + - Separate metrics for PQ, PV, and REF buses + - Detailed per-bus predictions + +- **[OptimalPowerFlowTask](optimal_power_flow.md)**: Optimal power flow with economic optimization + - Minimizes generation costs + - Tracks optimality gap + - Monitors constraint violations (thermal, voltage, angle) + - Evaluates reactive power limits + +- **[StateEstimationTask](state_estimation.md)**: State estimation from noisy measurements + - Handles measurement noise and outliers + - Separate evaluation for outliers, masked values, and clean measurements + - Correlation analysis between predictions, measurements, and targets + +## Quick Reference + +### Method Overview + +All task classes implement the following core methods: + +| Method | Purpose | Implemented In | +|--------|---------|----------------| +| `__init__` | Initialize task with config and normalizers | All classes | +| `forward` | Forward pass through model | ReconstructionTask+ | +| `training_step` | Execute one training step | ReconstructionTask+ | +| `validation_step` | Execute one validation step | ReconstructionTask+ | +| `test_step` | Execute one test step | Concrete tasks | +| `predict_step` | Execute one prediction step | Concrete tasks | +| `on_fit_start` | Save normalization stats before training | BaseTask | +| `on_test_end` | Generate reports and plots after testing | Concrete tasks | +| `configure_optimizers` | Setup optimizer and scheduler | BaseTask | + +### Task Selection + +Tasks are automatically selected based on your YAML configuration: + +```yaml +task: + task_name: PowerFlow # or OptimalPowerFlow, StateEstimation +``` + +The task registry automatically instantiates the correct task class based on the `task_name` field. + +## Common Features + +All tasks share these features: + +### 1. Distributed Training Support +- Multi-GPU training with proper metric synchronization +- Rank 0 handles logging and file I/O +- Automatic gathering of test outputs across ranks + +### 2. Comprehensive Logging +- Training and validation metrics logged to MLflow or TensorBoard +- Automatic hyperparameter tracking +- Normalization statistics saved for reproducibility + +### 3. Test Outputs +- CSV reports with detailed metrics +- Visualization plots (when `verbose=True`) +- Per-dataset analysis for multiple test sets + +### 4. Physics-Based Evaluation +- Power balance error computation +- Branch flow calculations +- Residual analysis by bus type + +## Configuration + +### Basic Configuration + +```yaml +task: + task_name: PowerFlow + verbose: true + +training: + batch_size: 64 + epochs: 100 + losses: ["MaskedMSE", "PBE"] + loss_weights: [0.01, 0.99] + +optimizer: + learning_rate: 0.001 + beta1: 0.9 + beta2: 0.999 + lr_decay: 0.7 + lr_patience: 5 +``` + +### Task-Specific Options + +Each task may have additional configuration options. See the individual task documentation for details: + +- [Power Flow Configuration](power_flow.md#configuration-example) +- [Optimal Power Flow Configuration](optimal_power_flow.md#configuration-example) +- [State Estimation Configuration](state_estimation.md#configuration-example) + +## Creating Custom Tasks + +To create a custom task, extend `ReconstructionTask` or `BaseTask`: + +```python +from gridfm_graphkit.tasks.reconstruction_tasks import ReconstructionTask +from gridfm_graphkit.io.registries import TASK_REGISTRY + +@TASK_REGISTRY.register("MyCustomTask") +class MyCustomTask(ReconstructionTask): + def __init__(self, args, data_normalizers): + super().__init__(args, data_normalizers) + # Add custom initialization + + def test_step(self, batch, batch_idx, dataloader_idx=0): + # Implement custom test logic + output, loss_dict = self.shared_step(batch) + + # Add custom metrics + custom_metric = self.compute_custom_metric(output, batch) + loss_dict["Custom Metric"] = custom_metric + + # Log metrics + for metric, value in loss_dict.items(): + self.log(f"{dataset_name}/{metric}", value) + + return loss_dict["loss"] + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + # Implement custom prediction logic + output, _ = self.shared_step(batch) + return {"predictions": output} + + def on_test_end(self): + # Custom analysis and visualization + # Generate reports, plots, etc. + super().on_test_end() +``` + +Then use it in your configuration: + +```yaml +task: + task_name: MyCustomTask +``` + +## Related Documentation + +- [Loss Functions](../training/loss.md): Available loss functions and their configuration +- [Data Modules](../datasets/data_modules.md): Data loading and preprocessing +- [Models](../models/models.md): Available model architectures +- [Quick Start Guide](../quick_start/quick_start.md): Getting started with training diff --git a/docs/tasks/optimal_power_flow.md b/docs/tasks/optimal_power_flow.md new file mode 100644 index 0000000..3d13a57 --- /dev/null +++ b/docs/tasks/optimal_power_flow.md @@ -0,0 +1,12 @@ +# Optimal Power Flow Task + +## OptimalPowerFlowTask Class + +::: gridfm_graphkit.tasks.opf_task.OptimalPowerFlowTask + options: + show_root_heading: true + show_source: true + members: + - __init__ + - test_step + - on_test_end diff --git a/docs/tasks/power_flow.md b/docs/tasks/power_flow.md new file mode 100644 index 0000000..8912a26 --- /dev/null +++ b/docs/tasks/power_flow.md @@ -0,0 +1,12 @@ +# Power Flow Task + +## PowerFlowTask Class + +::: gridfm_graphkit.tasks.pf_task.PowerFlowTask + options: + show_root_heading: true + show_source: true + members: + - __init__ + - test_step + - on_test_end diff --git a/docs/tasks/reconstruction_task.md b/docs/tasks/reconstruction_task.md new file mode 100644 index 0000000..54e9e5a --- /dev/null +++ b/docs/tasks/reconstruction_task.md @@ -0,0 +1,293 @@ +# Reconstruction Task + +The `ReconstructionTask` class is a concrete implementation of `BaseTask` that provides the foundation for node feature reconstruction on power grid graphs. It wraps a GridFM model and defines the training, validation, and testing logic for reconstructing masked node features. + +## Overview + +`ReconstructionTask` serves as the base class for all reconstruction-based tasks in GridFM-GraphKit, including: + +- Power Flow (PF) +- Optimal Power Flow (OPF) +- State Estimation (SE) + +It provides: + +- **Model integration**: Loads and wraps the GridFM model +- **Loss function handling**: Configures and applies loss functions +- **Shared training logic**: Common training and validation steps +- **Test output management**: Collects and manages test outputs for analysis + +## ReconstructionTask Class + +::: gridfm_graphkit.tasks.reconstruction_tasks.ReconstructionTask + options: + show_root_heading: true + show_source: true + members: + - __init__ + - forward + - shared_step + - training_step + - validation_step + - on_test_end + +## Methods + +### `__init__(args, data_normalizers)` + +Initialize the reconstruction task with model, loss function, and configuration. + +**Parameters:** + +- `args` (NestedNamespace): Experiment configuration with fields like: + - `training.batch_size`: Batch size for training + - `optimizer.*`: Optimizer configuration + - `model.*`: Model architecture configuration + - `training.losses`: List of loss functions to use + - `data.networks`: List of network names +- `data_normalizers` (list): One normalizer per dataset for feature normalization/denormalization + +**Attributes Set:** + +- `self.model`: GridFM model loaded via `load_model()` +- `self.loss_fn`: Loss function resolved from configuration via `get_loss_function()` +- `self.batch_size`: Training batch size +- `self.test_outputs`: Dictionary to store test outputs per dataset (keyed by dataloader index) + +**Example:** + +```python +task = ReconstructionTask(args, data_normalizers) +``` + +--- + +### `forward(x_dict, edge_index_dict, edge_attr_dict, mask_dict)` + +Forward pass through the model. + +**Parameters:** + +- `x_dict` (dict): Node features dictionary with keys like `"bus"`, `"gen"` +- `edge_index_dict` (dict): Edge indices dictionary for heterogeneous edges +- `edge_attr_dict` (dict): Edge attributes dictionary +- `mask_dict` (dict): Masking dictionary indicating which features are masked + +**Returns:** + +- Model output dictionary with predicted node features + +**Example:** + +```python +output = task.forward( + x_dict=batch.x_dict, + edge_index_dict=batch.edge_index_dict, + edge_attr_dict=batch.edge_attr_dict, + mask_dict=batch.mask_dict +) +``` + +--- + +### `shared_step(batch)` + +Common logic for training and validation steps. + +**Parameters:** + +- `batch`: A batch from the dataloader containing: + - `x_dict`: Input node features + - `y_dict`: Target node features + - `edge_index_dict`: Edge connectivity + - `edge_attr_dict`: Edge attributes + - `mask_dict`: Feature masks + +**Returns:** + +- `output` (dict): Model predictions +- `loss_dict` (dict): Dictionary containing: + - `"loss"`: Total loss value + - Additional loss components (if applicable) + +**Behavior:** + +1. Performs forward pass through the model +2. Computes loss using the configured loss function +3. Returns both predictions and loss dictionary + +**Example:** + +```python +output, loss_dict = task.shared_step(batch) +total_loss = loss_dict["loss"] +``` + +--- + +### `training_step(batch)` + +Execute one training step. + +**Parameters:** + +- `batch`: Training batch from dataloader + +**Returns:** + +- Loss tensor for backpropagation + +**Logged Metrics:** + +- `"Training Loss"`: Total training loss +- `"Learning Rate"`: Current learning rate + +**Logging Configuration:** + +- `batch_size`: Number of graphs in batch +- `sync_dist=False`: No synchronization across GPUs during training +- `on_epoch=False`: Log per step, not per epoch +- `on_step=True`: Log at each training step +- `prog_bar=False`: Don't show in progress bar +- `logger=True`: Send to logger (e.g., MLflow) + +--- + +### `validation_step(batch, batch_idx)` + +Execute one validation step. + +**Parameters:** + +- `batch`: Validation batch from dataloader +- `batch_idx` (int): Index of the current batch + +**Returns:** + +- Loss tensor + +**Logged Metrics:** + +- `"Validation loss"`: Total validation loss +- Additional loss components (if multiple losses are used) + +**Logging Configuration:** + +- `batch_size`: Number of graphs in batch +- `sync_dist=True`: Synchronize metrics across GPUs +- `on_epoch=True`: Aggregate and log at epoch end +- `on_step=False`: Don't log individual steps +- `logger=True`: Send to logger + +**Note:** The validation loss is monitored by the learning rate scheduler for automatic learning rate reduction. + +--- + +### `on_test_end()` + +Called at the end of testing. Clears stored test outputs. + +**Behavior:** + +- Clears the `self.test_outputs` dictionary +- Only executes on rank 0 in distributed training (via `@rank_zero_only` decorator) +- Subclasses typically override this to add custom analysis, plotting, and CSV generation + +**Note:** This is a minimal implementation. Task-specific subclasses (PowerFlowTask, OptimalPowerFlowTask, StateEstimationTask) override this method to: + +- Generate detailed metrics CSV files +- Create visualization plots +- Save analysis results + +--- + +## Usage + +`ReconstructionTask` can be used directly for simple reconstruction tasks, but is typically subclassed for specific power system tasks: + +```python +from gridfm_graphkit.tasks.reconstruction_tasks import ReconstructionTask + +# Direct usage (simple reconstruction) +task = ReconstructionTask(args, data_normalizers) + +# Or create a subclass for custom behavior +class CustomReconstructionTask(ReconstructionTask): + def test_step(self, batch, batch_idx, dataloader_idx=0): + # Custom test logic + output, loss_dict = self.shared_step(batch) + # Add custom metrics + return loss_dict["loss"] + + def on_test_end(self): + # Custom analysis and visualization + super().on_test_end() +``` + +## Configuration Example + +```yaml +task: + task_name: Reconstruction # Or PowerFlow, OptimalPowerFlow, StateEstimation + +model: + type: GNS_heterogeneous + hidden_size: 48 + num_layers: 12 + attention_head: 8 + +training: + batch_size: 64 + epochs: 100 + losses: + - MaskedMSE + loss_weights: + - 1.0 + +optimizer: + learning_rate: 0.001 + beta1: 0.9 + beta2: 0.999 + lr_decay: 0.7 + lr_patience: 5 +``` + +## Loss Functions + +The reconstruction task supports various loss functions configured via the YAML file: + +- **MaskedMSE**: Mean squared error on masked features only +- **MaskedBusMSE**: MSE specifically for bus node features +- **LayeredWeightedPhysics**: Physics-based loss with layer-wise weighting +- **PBE**: Power Balance Error loss + +Multiple losses can be combined with weights: + +```yaml +training: + losses: + - LayeredWeightedPhysics + - MaskedBusMSE + loss_weights: + - 0.1 + - 0.9 + loss_args: + - base_weight: 0.5 + - {} +``` + +## Subclasses + +The following task classes extend `ReconstructionTask`: + +- **[PowerFlowTask](power_flow.md)**: Adds power flow-specific metrics and physics validation +- **[OptimalPowerFlowTask](optimal_power_flow.md)**: Adds economic optimization metrics and constraint violation tracking +- **[StateEstimationTask](state_estimation.md)**: Adds measurement-based estimation and outlier handling + +## Related + +- [Base Task](base_task.md): Abstract base class for all tasks +- [Power Flow Task](power_flow.md): Power flow analysis implementation +- [Optimal Power Flow Task](optimal_power_flow.md): OPF optimization implementation +- [State Estimation Task](state_estimation.md): State estimation implementation +- [Loss Functions](../training/loss.md): Available loss functions diff --git a/docs/tasks/state_estimation.md b/docs/tasks/state_estimation.md new file mode 100644 index 0000000..c3adbbc --- /dev/null +++ b/docs/tasks/state_estimation.md @@ -0,0 +1,11 @@ +# State Estimation Task + +::: gridfm_graphkit.tasks.se_task.StateEstimationTask + options: + show_root_heading: true + show_source: true + members: + - __init__ + - test_step + - on_test_end + - predict_step diff --git a/docs/training/loss.md b/docs/training/loss.md index 5cde707..de56d4b 100644 --- a/docs/training/loss.md +++ b/docs/training/loss.md @@ -1,49 +1,47 @@ # Loss Functions -### `Power Balance Equation Loss` +## Base Loss -$$ -\mathcal{L}_{\text{PBE}} = \frac{1}{N} \sum_{i=1}^N \left| (P_{G,i} - P_{D,i}) + j(Q_{G,i} - Q_{D,i}) - S_{\text{injection}, i} \right| -$$ - -::: gridfm_graphkit.training.loss.PBELoss +::: gridfm_graphkit.training.loss.BaseLoss --- -### `Mean Squared Error Loss` - -$$ -\mathcal{L}_{\text{MSE}} = \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y}_i)^2 -$$ +## Mean Squared Error Loss ::: gridfm_graphkit.training.loss.MSELoss --- -### `Masked Mean Squared Error Loss` - -$$ -\mathcal{L}_{\text{MaskedMSE}} = \frac{1}{|M|} \sum_{i \in M} (y_i - \hat{y}_i)^2 -$$ +## Masked Mean Squared Error Loss ::: gridfm_graphkit.training.loss.MaskedMSELoss --- -### `Scaled Cosine Error Loss` - -$$ -\mathcal{L}_{\text{SCE}} = \frac{1}{N} \sum_{i=1}^N \left(1 - \frac{\hat{y}^T_i \cdot y_i}{\|\hat{y}_i\| \|y_i\|}\right)^\alpha \text{ , } \alpha \geq 1 -$$ +## Masked Generator MSE Loss -::: gridfm_graphkit.training.loss.SCELoss +::: gridfm_graphkit.training.loss.MaskedGenMSE --- -### `Mixed Loss` +## Masked Bus MSE Loss -$$ -\mathcal{L}_{\text{Mixed}} = \sum_{m=1}^M w_m \cdot \mathcal{L}_m -$$ +::: gridfm_graphkit.training.loss.MaskedBusMSE + +--- + +## Mixed Loss ::: gridfm_graphkit.training.loss.MixedLoss + +--- + +## Layered Weighted Physics Loss + +::: gridfm_graphkit.training.loss.LayeredWeightedPhysicsLoss + +--- + +## Loss Per Dimension + +::: gridfm_graphkit.training.loss.LossPerDim diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py index bafaf24..c01d57a 100644 --- a/gridfm_graphkit/__main__.py +++ b/gridfm_graphkit/__main__.py @@ -1,6 +1,6 @@ import argparse from datetime import datetime -from gridfm_graphkit.cli import main_cli +from gridfm_graphkit.cli import main_cli, benchmark_cli def main(): @@ -18,12 +18,36 @@ def main(): train_parser.add_argument("--run_name", type=str, default="run") train_parser.add_argument("--log_dir", type=str, default="mlruns") train_parser.add_argument("--data_path", type=str, default="data") + train_parser.add_argument( + "--dataset_wrapper", + type=str, + default=None, + help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset", + ) + train_parser.add_argument( + "--plugins", + nargs="*", + default=[], + help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee", + ) + train_parser.add_argument( + "--num_workers", + type=int, + default=None, + help="Override data.workers from the YAML config. Use 0 to debug worker crashes.", + ) + train_parser.add_argument( + "--dataset_wrapper_cache_dir", + type=str, + default=None, + help="Directory for the dataset wrapper's disk cache.", + ) train_parser.add_argument( "--profiler", type=str, default=None, choices=["simple", "advanced", "pytorch"], - help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", + help="Enable Lightning profiler.", ) # ---- FINETUNE SUBCOMMAND ---- @@ -34,12 +58,36 @@ def main(): finetune_parser.add_argument("--run_name", type=str, default="run") finetune_parser.add_argument("--log_dir", type=str, default="mlruns") finetune_parser.add_argument("--data_path", type=str, default="data") + finetune_parser.add_argument( + "--dataset_wrapper", + type=str, + default=None, + help="Registered name of a dataset wrapper.", + ) + finetune_parser.add_argument( + "--plugins", + nargs="*", + default=[], + help="Python packages to import for plugin registration.", + ) + finetune_parser.add_argument( + "--num_workers", + type=int, + default=None, + help="Override data.workers from the YAML config.", + ) + finetune_parser.add_argument( + "--dataset_wrapper_cache_dir", + type=str, + default=None, + help="Directory for the dataset wrapper's disk cache.", + ) finetune_parser.add_argument( "--profiler", type=str, default=None, choices=["simple", "advanced", "pytorch"], - help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", + help="Enable Lightning profiler.", ) # ---- EVALUATE SUBCOMMAND ---- @@ -52,58 +100,94 @@ def main(): "--normalizer_stats", type=str, default=None, - help="Path to normalizer_stats.pt from a training run. " - "Restores normalizers from saved stats instead of re-fitting.", + help="Path to normalizer_stats.pt from a training run.", ) evaluate_parser.add_argument("--config", type=str, required=True) evaluate_parser.add_argument("--exp_name", type=str, default=exp_name) evaluate_parser.add_argument("--run_name", type=str, default="run") evaluate_parser.add_argument("--log_dir", type=str, default="mlruns") evaluate_parser.add_argument("--data_path", type=str, default="data") + evaluate_parser.add_argument( + "--dataset_wrapper", + type=str, + default=None, + help="Registered name of a dataset wrapper.", + ) + evaluate_parser.add_argument( + "--plugins", + nargs="*", + default=[], + help="Python packages to import for plugin registration.", + ) + evaluate_parser.add_argument( + "--num_workers", + type=int, + default=None, + help="Override data.workers from the YAML config.", + ) + evaluate_parser.add_argument( + "--dataset_wrapper_cache_dir", + type=str, + default=None, + help="Directory for the dataset wrapper's disk cache.", + ) evaluate_parser.add_argument( "--profiler", type=str, default=None, choices=["simple", "advanced", "pytorch"], - help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", + help="Enable Lightning profiler.", ) evaluate_parser.add_argument( "--compute_dc_ac_metrics", action="store_true", - help="Compute ground-truth AC/DC power balance metrics on the test split.", ) evaluate_parser.add_argument( "--save_output", action="store_true", - help="Save per-bus predictions CSV via the predict step.", ) + # ---- PREDICT SUBCOMMAND ---- - predict_parser = subparsers.add_parser("predict", help="Evaluate model performance") - predict_parser.add_argument("--model_path", type=str, required=None) - predict_parser.add_argument( - "--normalizer_stats", - type=str, - default=None, - help="Path to normalizer_stats.pt from a training run. " - "Restores normalizers from saved stats instead of re-fitting.", - ) + predict_parser = subparsers.add_parser("predict", help="Run prediction") + predict_parser.add_argument("--model_path", type=str, required=False) + predict_parser.add_argument("--normalizer_stats", type=str, default=None) predict_parser.add_argument("--config", type=str, required=True) predict_parser.add_argument("--exp_name", type=str, default=exp_name) predict_parser.add_argument("--run_name", type=str, default="run") predict_parser.add_argument("--log_dir", type=str, default="mlruns") predict_parser.add_argument("--data_path", type=str, default="data") + predict_parser.add_argument("--dataset_wrapper", type=str, default=None) + predict_parser.add_argument("--plugins", nargs="*", default=[]) + predict_parser.add_argument("--num_workers", type=int, default=None) + predict_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None) predict_parser.add_argument("--output_path", type=str, default="data") predict_parser.add_argument( "--profiler", type=str, default=None, choices=["simple", "advanced", "pytorch"], - help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.", ) + # ---- BENCHMARK SUBCOMMAND ---- + benchmark_parser = subparsers.add_parser( + "benchmark", + help="Benchmark train-dataloader iteration speed", + ) + benchmark_parser.add_argument("--config", type=str, required=True) + benchmark_parser.add_argument("--data_path", type=str, default="data") + benchmark_parser.add_argument("--epochs", type=int, default=3) + benchmark_parser.add_argument("--dataset_wrapper", type=str, default=None) + benchmark_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None) + benchmark_parser.add_argument("--num_workers", type=int, default=None) + benchmark_parser.add_argument("--plugins", nargs="*", default=[]) + args = parser.parse_args() - main_cli(args) + + if args.command == "benchmark": + benchmark_cli(args) + else: + main_cli(args) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py index 23fdf08..44ee1e3 100644 --- a/gridfm_graphkit/cli.py +++ b/gridfm_graphkit/cli.py @@ -1,8 +1,11 @@ from gridfm_graphkit.datasets.hetero_powergrid_datamodule import LitGridHeteroDataModule from gridfm_graphkit.io.param_handler import NestedNamespace +from gridfm_graphkit.io.registries import DATASET_WRAPPER_REGISTRY from gridfm_graphkit.training.callbacks import SaveBestModelStateDict +import importlib import numpy as np import os +import time import yaml import torch import pandas as pd @@ -15,6 +18,86 @@ import lightning as L +def _load_plugins(plugins: list[str]) -> None: + """Import plugin packages so their registry decorators fire.""" + for plugin_pkg in plugins: + try: + importlib.import_module(plugin_pkg) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Plugin package '{plugin_pkg}' could not be imported: {e}. " + "Make sure it is installed in the current environment.", + ) from e + + +def _validate_dataset_wrapper(name: str | None) -> None: + """Raise a helpful error if *name* is not registered in DATASET_WRAPPER_REGISTRY.""" + if name is None: + return + if name not in DATASET_WRAPPER_REGISTRY: + available = list(DATASET_WRAPPER_REGISTRY) + raise KeyError( + f"Dataset wrapper '{name}' is not registered. " + f"Available wrappers: {available}. " + "If it lives in a plugin package, pass it via --plugins.", + ) + + +def benchmark_cli(args): + """Benchmark train-dataloader iteration speed over one or more epochs.""" + with open(args.config, "r") as f: + base_config = yaml.safe_load(f) + + config_args = NestedNamespace(**base_config) + + num_workers_override = getattr(args, "num_workers", None) + if num_workers_override is not None: + config_args.data.workers = num_workers_override + + _load_plugins(getattr(args, "plugins", [])) + + dataset_wrapper = getattr(args, "dataset_wrapper", None) + dataset_wrapper_cache_dir = getattr(args, "dataset_wrapper_cache_dir", None) + _validate_dataset_wrapper(dataset_wrapper) + + print("Setting up datamodule...") + t0 = time.perf_counter() + dm = LitGridHeteroDataModule( + config_args, + args.data_path, + dataset_wrapper=dataset_wrapper, + dataset_wrapper_cache_dir=dataset_wrapper_cache_dir, + ) + dm.setup(stage="fit") + setup_time = time.perf_counter() - t0 + print(f" Setup time : {setup_time:.2f}s") + + loader = dm.train_dataloader() + num_batches = len(loader) + print(f" Train batches : {num_batches}") + print(f" Batch size : {config_args.training.batch_size}") + print(f" Workers : {config_args.data.workers}") + print(f" Dataset wrapper : {dataset_wrapper or 'none'}") + print() + + epoch_times = [] + for epoch in range(args.epochs): + t_start = time.perf_counter() + for _batch in loader: + pass + elapsed = time.perf_counter() - t_start + per_batch = elapsed / num_batches if num_batches > 0 else 0.0 + epoch_times.append(elapsed) + print( + f"Epoch {epoch:>3}: {elapsed:7.3f}s total " + f"{per_batch:.4f}s/batch ({num_batches} batches)", + ) + + if args.epochs > 1: + avg = sum(epoch_times) / len(epoch_times) + print(f"\nAverage over {args.epochs} epochs: {avg:.3f}s") + + def get_training_callbacks(args): early_stop_callback = EarlyStopping( monitor="Validation loss", @@ -55,10 +138,23 @@ def main_cli(args): L.seed_everything(config_args.seed, workers=True) normalizer_stats_path = getattr(args, "normalizer_stats", None) + dataset_wrapper = getattr(args, "dataset_wrapper", None) + dataset_wrapper_cache_dir = getattr(args, "dataset_wrapper_cache_dir", None) + + # CLI --num_workers overrides the YAML value (useful for debugging with 0) + num_workers_override = getattr(args, "num_workers", None) + if num_workers_override is not None: + config_args.data.workers = num_workers_override + + _load_plugins(getattr(args, "plugins", [])) + _validate_dataset_wrapper(dataset_wrapper) + litGrid = LitGridHeteroDataModule( config_args, args.data_path, normalizer_stats_path=normalizer_stats_path, + dataset_wrapper=dataset_wrapper, + dataset_wrapper_cache_dir=dataset_wrapper_cache_dir, ) model = get_task(config_args, litGrid.data_normalizers) if args.command != "train": diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 4ac0125..acd45aa 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -1,9 +1,11 @@ import json import torch +import os from torch_geometric.loader import DataLoader from torch.utils.data import ConcatDataset from torch.utils.data import Subset import torch.distributed as dist +from gridfm_graphkit.io.registries import DATASET_WRAPPER_REGISTRY from gridfm_graphkit.io.param_handler import ( NestedNamespace, load_normalizer, @@ -17,7 +19,6 @@ import numpy as np import random import warnings -import os import lightning as L from typing import List from lightning.pytorch.loggers import MLFlowLogger @@ -87,9 +88,13 @@ def __init__( args: NestedNamespace, data_dir: str = "./data", normalizer_stats_path: str = None, + dataset_wrapper: str = None, + dataset_wrapper_cache_dir: str = None, ): super().__init__() self.data_dir = data_dir + self.dataset_wrapper = dataset_wrapper + self.dataset_wrapper_cache_dir = dataset_wrapper_cache_dir self.batch_size = int(args.training.batch_size) self.split_by_load_scenario_idx = getattr( args.data, @@ -149,6 +154,7 @@ def setup(self, stage: str): data_normalizer=data_normalizer, transform=get_task_transforms(args=self.args), ) + self.datasets.append(dataset) num_scenarios = self.args.data.scenarios[i] @@ -171,6 +177,10 @@ def setup(self, stage: str): dataset = Subset(dataset, subset_indices) + if self.dataset_wrapper is not None: + wrapper_cls = DATASET_WRAPPER_REGISTRY.get(self.dataset_wrapper) + dataset = wrapper_cls(dataset, cache_dir=self.dataset_wrapper_cache_dir) + # Random seed set before every split, same as above np.random.seed(self.args.seed) if self.split_by_load_scenario_idx: @@ -229,6 +239,11 @@ def setup(self, stage: str): saved_stats, ) + # Populate the wrapper cache now that the normalizer is fitted, + # so transform() has BaseMVA set when __getitem__ is called. + if self.dataset_wrapper is not None and hasattr(dataset, "_setup_cache"): + dataset._setup_cache() + self.train_datasets.append(train_dataset) self.val_datasets.append(val_dataset) self.test_datasets.append(test_dataset) @@ -244,7 +259,11 @@ def setup(self, stage: str): is_rank0 = ( not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 ) - if is_rank0 and self.trainer is not None and self.trainer.logger is not None: + if ( + is_rank0 + and self.trainer is not None + and getattr(self.trainer, "logger", None) is not None + ): logger = self.trainer.logger if isinstance(logger, MLFlowLogger): log_dir = os.path.join( @@ -343,13 +362,21 @@ def save_scenario_splits(self, log_dir: str): with open(splits_path, "w") as f: json.dump(splits, f, indent=2) + def _dataloader_kwargs(self): + num_workers = self.args.data.workers + kwargs = dict( + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + persistent_workers=num_workers > 0, + ) + return kwargs + def train_dataloader(self): return DataLoader( self.train_dataset_multi, batch_size=self.batch_size, shuffle=True, - num_workers=self.args.data.workers, - pin_memory=True, + **self._dataloader_kwargs(), ) def val_dataloader(self): @@ -357,8 +384,7 @@ def val_dataloader(self): self.val_dataset_multi, batch_size=self.batch_size, shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, + **self._dataloader_kwargs(), ) def test_dataloader(self): @@ -367,8 +393,7 @@ def test_dataloader(self): i, batch_size=self.batch_size, shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, + **self._dataloader_kwargs(), ) for i in self.test_datasets ] @@ -379,8 +404,7 @@ def predict_dataloader(self): i, batch_size=self.batch_size, shuffle=False, - num_workers=self.args.data.workers, - pin_memory=True, + **self._dataloader_kwargs(), ) for i in self.test_datasets ] diff --git a/gridfm_graphkit/io/registries.py b/gridfm_graphkit/io/registries.py index 27d56b7..32feb20 100644 --- a/gridfm_graphkit/io/registries.py +++ b/gridfm_graphkit/io/registries.py @@ -43,3 +43,4 @@ def __len__(self): TASK_REGISTRY = Registry("task") TRANSFORM_REGISTRY = Registry("transform") PHYSICS_DECODER_REGISTRY = Registry("physics_decoder") +DATASET_WRAPPER_REGISTRY = Registry("dataset_wrapper") diff --git a/gridfm_graphkit/tasks/opf_task.py b/gridfm_graphkit/tasks/opf_task.py index b28c5a0..06d938d 100644 --- a/gridfm_graphkit/tasks/opf_task.py +++ b/gridfm_graphkit/tasks/opf_task.py @@ -256,8 +256,8 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): loss_dict["Opt gap"] = optimality_gap loss_dict["MSE PG"] = mse_PG[PG_H] - loss_dict["Branch termal violation from"] = mean_thermal_violation_forward - loss_dict["Branch termal violation to"] = mean_thermal_violation_reverse + loss_dict["Branch thermal violation from"] = mean_thermal_violation_forward + loss_dict["Branch thermal violation to"] = mean_thermal_violation_reverse loss_dict["Branch voltage angle difference violations"] = ( branch_angle_violation_mean ) diff --git a/integrationtests/test_base_set.py b/integrationtests/test_base_set.py new file mode 100644 index 0000000..90da468 --- /dev/null +++ b/integrationtests/test_base_set.py @@ -0,0 +1,160 @@ +import pytest +import subprocess +import os +import glob +import pandas as pd +import yaml +import urllib.request +import shutil + + +def execute_and_live_output(cmd) -> None: + subprocess.run(cmd, text=True, shell=True, check=True) + + +def prepare_config(): + """ + Download default.yaml from gridfm-datakit repo and modify it with test parameters. + """ + config_url = "https://raw.githubusercontent.com/gridfm/gridfm-datakit/refs/heads/main/scripts/config/default.yaml" + config_path = "integrationtests/default.yaml" + + print(f"Downloading config from {config_url}...") + with urllib.request.urlopen(config_url) as response: + config_content = response.read().decode("utf-8") + + config = yaml.safe_load(config_content) + + config["network"]["name"] = "case14_ieee" + config["load"]["scenarios"] = 10000 + config["topology_perturbation"]["n_topology_variants"] = 2 + + with open(config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + print(f"Config prepared at {config_path} with:") + print(f" - network.name: {config['network']['name']}") + print(f" - load.scenarios: {config['load']['scenarios']}") + print( + f" - topology_perturbation.n_topology_variants: " + f"{config['topology_perturbation']['n_topology_variants']}", + ) + + return config_path + + +def prepare_training_config(): + """ + Modify the training config to set epochs to 2 for testing. + """ + config_path = "examples/config/HGNS_PF_datakit_case14.yaml" + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + if "training" not in config: + config["training"] = {} + + config["training"]["epochs"] = 2 + + with open(config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + print(f"Training config updated: epochs set to {config['training']['epochs']}") + + return config_path + + +@pytest.fixture +def cleanup_test_artifacts(): + """ + Backup modified files and remove generated artifacts after the test. + """ + training_config = " " + backup_config = training_config + ".bak" + + if os.path.exists(training_config): + shutil.copy2(training_config, backup_config) + + yield + + # Restore training config + if os.path.exists(backup_config): + shutil.move(backup_config, training_config) + + # Remove downloaded config + config_file = "integrationtests/default.yaml" + if os.path.exists(config_file): + os.remove(config_file) + + # Remove generated directories + for d in ["data_out", "logs"]: + if os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + +def test_train(cleanup_test_artifacts): + """ + Integration test for gridfm-datakit data generation and gridfm-graphkit training. + + Steps: + 1. Generate power grid data using gridfm-datakit + 2. Train a model using gridfm-graphkit + 3. Validate the PBE Mean metric + """ + + data_dir = "data_out" + + if not os.path.exists(data_dir) or not os.listdir(data_dir): + print("Data directory not found or empty, generating data...") + + config_path = prepare_config() + + execute_and_live_output(f"gridfm_datakit generate {config_path}") + else: + print(f"Data directory '{data_dir}' already exists, skipping generation.") + + training_config_path = prepare_training_config() + + execute_and_live_output( + f"gridfm_graphkit train " + f"--config {training_config_path} " + f"--data_path data_out/ " + f"--exp_name exp1 " + f"--run_name run1 " + f"--log_dir logs", + ) + + log_base = "logs" + + exp_dirs = glob.glob(os.path.join(log_base, "*")) + assert len(exp_dirs) > 0, "No experiment directories found in logs/" + + latest_exp_dir = sorted(exp_dirs, key=os.path.getctime)[-1] + + run_dirs = glob.glob(os.path.join(latest_exp_dir, "*")) + assert len(run_dirs) > 0, f"No run directories found in {latest_exp_dir}" + + latest_run_dir = max(run_dirs, key=os.path.getmtime) + + metrics_file = os.path.join( + latest_run_dir, + "artifacts", + "test", + "case14_ieee_metrics.csv", + ) + + assert os.path.exists(metrics_file), f"Metrics file not found: {metrics_file}" + + df = pd.read_csv(metrics_file) + + pbe_mean_row = df[df["Metric"] == "PBE Mean"] + assert len(pbe_mean_row) > 0, "PBE Mean metric not found in CSV" + + pbe_mean_value = float(pbe_mean_row.iloc[0]["Value"]) + + assert 1.1 <= pbe_mean_value <= 2.9, ( + f"PBE Mean value {pbe_mean_value} is outside acceptable range [1.1, 2.9]" + ) + + print(f"PBE Mean value {pbe_mean_value} is within acceptable range [1.1, 2.9]") diff --git a/mkdocs.yml b/mkdocs.yml index afc3359..6581214 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,7 +19,12 @@ nav: - Data Modules: datasets/data_modules.md - Transforms: datasets/transforms.md - Tasks: - - Feature Reconstruction: tasks/feature_reconstruction.md + - Overview: tasks/feature_reconstruction.md + - Base Task: tasks/base_task.md + - Reconstruction Task: tasks/reconstruction_task.md + - Power Flow Task: tasks/power_flow.md + - Optimal Power Flow Task: tasks/optimal_power_flow.md + - State Estimation Task: tasks/state_estimation.md - Models: models/models.md - Training: - Losses: training/loss.md diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..be6a68c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,60 @@ +""" +Session-scoped fixture that ensures the processed test data directory is +populated before any test that needs it runs. + +Specifically it: + 1. Runs LitGridHeteroDataModule.setup("fit") which triggers + HeteroGridDatasetDisk to write the ``processed/`` .pt files. + 2. Persists the fitted normalizer stats as + ``tests/data/case14_ieee/processed/data_stats_HeteroDataMVANormalizer.pt`` + so that test_edge_flows.py and test_simulate_measurements.py can load + them directly without needing a full DM setup. +""" + +import os + +import pytest +import torch +import yaml + +from gridfm_graphkit.datasets.hetero_powergrid_datamodule import LitGridHeteroDataModule +from gridfm_graphkit.datasets.normalizers import HeteroDataMVANormalizer +from gridfm_graphkit.io.param_handler import NestedNamespace + +_STATS_PATH = "tests/data/case14_ieee/processed/data_stats_HeteroDataMVANormalizer.pt" +_CONFIG_PATH = "tests/config/datamodule_test_base_config.yaml" + + +class _DummyTrainer: + """Minimal stand-in for a Lightning Trainer used only during test setup.""" + + is_global_zero = True + logger = None # prevents AttributeError in hetero_powergrid_datamodule.setup() + + +@pytest.fixture(scope="session", autouse=True) +def generate_processed_test_data(): + """ + Generate processed test data files that are needed by tests which load + them directly (test_edge_flows, test_simulate_measurements). + + Skipped silently if the stats file already exists (e.g., second pytest run + in the same environment without cleaning the processed/ directory). + """ + if os.path.exists(_STATS_PATH): + return + + with open(_CONFIG_PATH) as f: + config_dict = yaml.safe_load(f) + + args = NestedNamespace(**config_dict) + dm = LitGridHeteroDataModule(args, data_dir="tests/data") + dm.trainer = _DummyTrainer() + dm.setup("fit") + + # Persist the fitted normalizer stats under the name used by the tests. + normalizer = dm.data_normalizers[0] + assert isinstance(normalizer, HeteroDataMVANormalizer), ( + f"Expected HeteroDataMVANormalizer, got {type(normalizer).__name__}" + ) + torch.save(normalizer.get_stats(), _STATS_PATH)