# gridfm-graphkit

This library is brought to you by the GridFM team to train, fine-tune, and interact with a foundation model for the electric power grid.
## Installation

Create and activate a virtual environment (Python 3.10, 3.11, or 3.12 is required; 3.12 is recommended):

```bash
python -m venv venv
source venv/bin/activate
```

Install gridfm-graphkit in editable mode:

```bash
pip install -e .
```

Get the PyTorch + CUDA version for torch-scatter:

```bash
TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")
```

Install the correct torch-scatter wheel:

```bash
pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html
```

For documentation generation and unit testing, install with the optional dev and test extras:

```bash
pip install -e .[dev,test]
```

## Command-Line Interface

The CLI provides an interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.
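The `TORCH_CUDA_VERSION` one-liner from the install steps above can be restated as a small Python helper, which may make the logic easier to follow (a sketch; `scatter_index_suffix` is a hypothetical name, not part of gridfm-graphkit):

```python
def scatter_index_suffix(torch_version, cuda_version):
    """Mirror of the shell one-liner: append '+cpu' when CUDA is unavailable,
    otherwise use torch.__version__ as-is (CUDA builds already carry a +cuXYZ tag)."""
    return torch_version + ("+cpu" if cuda_version is None else "")

# CPU-only build of torch 2.3.1:
print(scatter_index_suffix("2.3.1", None))        # 2.3.1+cpu
# CUDA build: the version string is used unchanged.
print(scatter_index_suffix("2.3.1+cu121", "12.1"))  # 2.3.1+cu121
```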
```bash
gridfm_graphkit <command> [OPTIONS]
```

Available commands:

- `train` - Train a new model from scratch
- `finetune` - Fine-tune an existing pre-trained model
- `evaluate` - Evaluate model performance on a dataset
- `predict` - Run inference and save predictions
- `benchmark` - Time iteration over the train dataloader
### train

```bash
gridfm_graphkit train --config path/to/config.yaml
```

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | str | **Required.** Path to the training configuration YAML file. | None |
| `--exp_name` | str | MLflow experiment name. | timestamp |
| `--run_name` | str | MLflow run name. | run |
| `--log_dir` | str | MLflow tracking/logging directory. | mlruns |
| `--data_path` | str | Root dataset directory. | data |
| `--compile [MODE]` | str | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | None |
| `--bfloat16` | flag | Cast the model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | str | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | list[str] | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | int | Override `data.workers` from the YAML. Use 0 to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after first population. | None |
| `--profiler` | str | Enable a Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
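The `--compile [MODE]` behavior described above (a value that is optional, falling back to `default` when the flag is given bare) is a standard optional-value flag pattern. A minimal sketch of how such a flag can be declared with Python's `argparse` (illustrative only, not gridfm-graphkit's actual parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--compile",
    nargs="?",        # the value after the flag is optional
    const="default",  # used when --compile is passed without a value
    default=None,     # used when --compile is absent entirely
    choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
)

print(parser.parse_args([]).compile)                              # None
print(parser.parse_args(["--compile"]).compile)                   # default
print(parser.parse_args(["--compile", "max-autotune"]).compile)   # max-autotune
```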
Standard Training:

```bash
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data
```

### finetune

```bash
gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt
```

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | str | **Required.** Fine-tuning configuration file. | None |
| `--model_path` | str | **Required.** Path to a pre-trained model state dict. | None |
| `--exp_name` | str | MLflow experiment name. | timestamp |
| `--run_name` | str | MLflow run name. | run |
| `--log_dir` | str | MLflow logging directory. | mlruns |
| `--data_path` | str | Root dataset directory. | data |
| `--compile [MODE]` | str | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | None |
| `--bfloat16` | flag | Cast the model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | str | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | list[str] | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | int | Override `data.workers` from the YAML. Use 0 to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after first population. | None |
| `--profiler` | str | Enable a Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
### evaluate

```bash
gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt
```

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | str | **Required.** Path to the evaluation config. | None |
| `--model_path` | str | Path to the trained model state dict. | None |
| `--normalizer_stats` | str | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current split. | None |
| `--exp_name` | str | MLflow experiment name. | timestamp |
| `--run_name` | str | MLflow run name. | run |
| `--log_dir` | str | MLflow logging directory. | mlruns |
| `--data_path` | str | Dataset directory. | data |
| `--compile [MODE]` | str | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | None |
| `--bfloat16` | flag | Cast the model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--dataset_wrapper` | str | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | list[str] | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | int | Override `data.workers` from the YAML. Use 0 to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after first population. | None |
| `--profiler` | str | Enable a Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | False |
| `--save_output` | flag | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | False |
When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:

```bash
gridfm_graphkit evaluate \
    --config examples/config/HGNS_PF_datakit_case118.yaml \
    --model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
    --normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
    --data_path data
```

Note: the `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
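To make the distinction concrete, here is a minimal sketch in plain Python (not the library's API; both class names are hypothetical) of a fit-on-train normalizer, whose statistics are computed once and can be saved and restored, versus a per-sample normalizer, whose scale is derived fresh from each sample:

```python
class FitOnTrainNormalizer:
    """Scale is fitted once on the training split and reused everywhere.
    This is the kind of state that --normalizer_stats saves and restores."""

    def __init__(self):
        self.scale = None

    def fit(self, samples):
        # One global statistic over the whole training split.
        self.scale = max(max(abs(v) for v in s) for s in samples)

    def transform(self, sample):
        return [v / self.scale for v in sample]


class PerSampleNormalizer:
    """Scale is recomputed from each sample; saved statistics are irrelevant,
    which is why --normalizer_stats has no effect on this kind of normalizer."""

    def transform(self, sample):
        scale = max(abs(v) for v in sample)
        return [v / scale for v in sample]


norm = FitOnTrainNormalizer()
norm.fit([[1.0, -4.0], [2.0, 2.0]])
print(norm.scale)                                  # 4.0
print(norm.transform([2.0, -4.0]))                 # [0.5, -1.0]
print(PerSampleNormalizer().transform([2.0, -4.0]))  # [0.5, -1.0]
```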
### predict

```bash
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt
```

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | str | **Required.** Path to the prediction config file. | None |
| `--model_path` | str | Path to the trained model state dict. Optional; may be defined in the config. | None |
| `--normalizer_stats` | str | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | None |
| `--exp_name` | str | MLflow experiment name. | timestamp |
| `--run_name` | str | MLflow run name. | run |
| `--log_dir` | str | MLflow logging directory. | mlruns |
| `--data_path` | str | Dataset directory. | data |
| `--dataset_wrapper` | str | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--plugins` | list[str] | Python packages to import for plugin registration, e.g. `gridfm_graphkit_ee`. | [] |
| `--num_workers` | int | Override `data.workers` from the YAML. Use 0 to debug worker crashes. | None |
| `--dataset_wrapper_cache_dir` | str | Disk cache directory for the dataset wrapper; the cache is loaded from here when present and saved after first population. | None |
| `--output_path` | str | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | data |
| `--compile [MODE]` | str | Enable `torch.compile` mode. Valid values: `default`, `reduce-overhead`, `max-autotune`, `max-autotune-no-cudagraphs`. If the flag is passed without a value, the mode is `default`. | None |
| `--bfloat16` | flag | Cast the model to `torch.bfloat16` (`model.to(torch.bfloat16)`). | False |
| `--tf32` | flag | Enable TF32 on Ampere+ GPUs via `torch.set_float32_matmul_precision("high")`. | False |
| `--profiler` | str | Enable a Lightning profiler (`simple`, `advanced`, `pytorch`). | None |
### benchmark

```bash
gridfm_graphkit benchmark --config path/to/config.yaml
```

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | str | **Required.** Path to the configuration YAML file. | None |
| `--data_path` | str | Root dataset directory. | data |
| `--epochs` | int | Number of epochs to iterate through the train dataloader. | 3 |
| `--dataset_wrapper` | str | Registered dataset wrapper name (see `DATASET_WRAPPER_REGISTRY`), e.g. `SharedMemoryCacheDataset`. | None |
| `--dataset_wrapper_cache_dir` | str | Directory for the dataset wrapper disk cache. | None |
| `--num_workers` | int | Override `data.workers` from the YAML. | None |
| `--plugins` | list[str] | Python packages to import for plugin registration. | [] |
Use the built-in help for full command details:

```bash
gridfm_graphkit --help
gridfm_graphkit <command> --help
```