Merged
59 commits
ec01b3c
fix gitignore
romeokienzler Mar 26, 2026
2d103b2
add torch.compile support via --compile CLI flag
romeokienzler Mar 26, 2026
6b8aed9
fix torch.compile max-autotune OOM: allow ATEN GEMM fallback
romeokienzler Mar 26, 2026
b3fd8c4
fix torch.compile graph breaks: replace boolean index-put with torch.…
romeokienzler Mar 26, 2026
84e613f
add --tf32 and --bfloat16 parameters
romeokienzler Mar 30, 2026
2bb6f13
pass bfloat16 to lightning trainer as well if set
romeokienzler Mar 30, 2026
6dae6d5
enable autocast for HeteroData objects
romeokienzler Mar 30, 2026
c9fc63d
add smoke test
romeokienzler Mar 11, 2026
86fb524
fix value range
romeokienzler Mar 11, 2026
e2835e3
fix doc, improve test
romeokienzler Mar 13, 2026
6a15194
add missing doc
romeokienzler Mar 13, 2026
6735e77
improve test
romeokienzler Mar 13, 2026
5abe149
fix docs
romeokienzler Mar 14, 2026
345167b
fix ordering
romeokienzler Mar 14, 2026
c02271c
add fixture for cleanup
romeokienzler Mar 14, 2026
bed1e9e
fix test name
romeokienzler Mar 14, 2026
512dfc3
remove bad doc
romeokienzler Mar 15, 2026
2f557ae
fix broken links
romeokienzler Mar 15, 2026
9bf689d
fix documentation
romeokienzler Mar 15, 2026
afb569b
strip down
romeokienzler Mar 16, 2026
69d35ea
simplify doc
romeokienzler Mar 16, 2026
1e032be
add support for dataset wrapper
romeokienzler Mar 17, 2026
27f6c10
swap call order to support wrapping of ds
romeokienzler Mar 18, 2026
87e3e11
fix shared memory file handler issue
romeokienzler Mar 18, 2026
73bcdcb
fix race condition
romeokienzler Mar 18, 2026
886cf51
add logging for debugging race condition
romeokienzler Mar 18, 2026
848fb69
add registry support for (3rd party) dataset wrapper
romeokienzler Mar 18, 2026
6f25a99
remove debug code, add parameter
romeokienzler Mar 18, 2026
7f89a3a
remove unnecessary param
romeokienzler Mar 19, 2026
73c5552
fix cache buildup order
romeokienzler Mar 19, 2026
633d797
changes introduced by precommit
romeokienzler Mar 19, 2026
985cae5
manual precommit fixes
romeokienzler Mar 19, 2026
abd8db6
fix precommit hook
romeokienzler Mar 19, 2026
a93f2ba
precommit fix
romeokienzler Mar 19, 2026
abacdc4
bump trivy
romeokienzler Mar 19, 2026
d49c8d1
fix precommit
romeokienzler Mar 19, 2026
2691248
fix trivy
romeokienzler Mar 19, 2026
d702616
add support for dataloading performance tests
romeokienzler Mar 24, 2026
8954822
fix validation order
romeokienzler Mar 24, 2026
698c8a4
precommit fix
romeokienzler Mar 25, 2026
0a0c5db
security fix
romeokienzler Mar 25, 2026
5de3045
fix missing package and circular import
romeokienzler Mar 25, 2026
7651ad6
ignore security CVE-2026-4539 as not relevant
romeokienzler Mar 25, 2026
1684ffa
fix tests
romeokienzler Mar 25, 2026
d62f5bf
fix precommit
romeokienzler Mar 25, 2026
7ff6c4c
add profiler cli argument
romeokienzler Mar 20, 2026
7c1c245
add profiler cli argument
romeokienzler Mar 20, 2026
cb80c4c
add batch elapsed time logging
romeokienzler Mar 30, 2026
6e3e6a2
test for semantic equivalence empirically
romeokienzler Apr 1, 2026
6d95793
implement in aten safe way
romeokienzler Apr 1, 2026
7153039
log (to mlflow) only every 1000 steps
romeokienzler Apr 1, 2026
76e6bd5
fix syntax error
romeokienzler Apr 2, 2026
b793b67
Merge branch 'main' into add_perf_fix
romeokienzler Apr 2, 2026
1908ef5
fix merge error
romeokienzler Apr 2, 2026
0dfa6e3
fix
romeokienzler Apr 2, 2026
e092dbd
remove crashing code
romeokienzler Apr 2, 2026
775be84
fix precommit hook
romeokienzler Apr 2, 2026
458eecc
fix received 0 items of ancdata
romeokienzler Apr 2, 2026
51ca57f
fix precommit
romeokienzler Apr 2, 2026
6 changes: 6 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
*venv*
data_out
*venv
data_out
logs
@@ -14,4 +16,8 @@ gridfm_graphkit.egg-info
mlruns
*.pt
.DS_Store
.julia
*logs*
*data_out*
site*
.venv
125 changes: 103 additions & 22 deletions gridfm_graphkit/__main__.py
@@ -11,13 +11,40 @@ def main():
subparsers = parser.add_subparsers(dest="command", required=True)
exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

_compile_kwargs = dict(
type=str,
default=None,
nargs="?",
const="default",
choices=[
"default",
"reduce-overhead",
"max-autotune",
"max-autotune-no-cudagraphs",
],
help="Enable torch.compile with the given mode (omit value for 'default').",
)
_bfloat16_kwargs = dict(
action="store_true",
default=False,
help="Cast model to bfloat16 (model.to(torch.bfloat16)).",
)
_tf32_kwargs = dict(
action="store_true",
default=False,
help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').",
)

# ---- TRAIN SUBCOMMAND ----
train_parser = subparsers.add_parser("train", help="Run training")
train_parser.add_argument("--config", type=str, required=True)
train_parser.add_argument("--exp_name", type=str, default=exp_name)
train_parser.add_argument("--run_name", type=str, default="run")
train_parser.add_argument("--log_dir", type=str, default="mlruns")
train_parser.add_argument("--data_path", type=str, default="data")
train_parser.add_argument("--compile", **_compile_kwargs)
train_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
train_parser.add_argument("--tf32", **_tf32_kwargs)
train_parser.add_argument(
"--dataset_wrapper",
type=str,
@@ -40,14 +67,14 @@ def main():
"--dataset_wrapper_cache_dir",
type=str,
default=None,
help="Directory for the dataset wrapper's disk cache.",
help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
)
train_parser.add_argument(
"--profiler",
type=str,
default=None,
choices=["simple", "advanced", "pytorch"],
help="Enable Lightning profiler.",
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
)

# ---- FINETUNE SUBCOMMAND ----
@@ -58,36 +85,39 @@ def main():
finetune_parser.add_argument("--run_name", type=str, default="run")
finetune_parser.add_argument("--log_dir", type=str, default="mlruns")
finetune_parser.add_argument("--data_path", type=str, default="data")
finetune_parser.add_argument("--compile", **_compile_kwargs)
finetune_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
finetune_parser.add_argument("--tf32", **_tf32_kwargs)
finetune_parser.add_argument(
"--dataset_wrapper",
type=str,
default=None,
help="Registered name of a dataset wrapper.",
help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
)
finetune_parser.add_argument(
"--plugins",
nargs="*",
default=[],
help="Python packages to import for plugin registration.",
help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
)
finetune_parser.add_argument(
"--num_workers",
type=int,
default=None,
help="Override data.workers from the YAML config.",
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
finetune_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
default=None,
help="Directory for the dataset wrapper's disk cache.",
help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
)
finetune_parser.add_argument(
"--profiler",
type=str,
default=None,
choices=["simple", "advanced", "pytorch"],
help="Enable Lightning profiler.",
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
)

# ---- EVALUATE SUBCOMMAND ----
@@ -107,36 +137,39 @@ def main():
evaluate_parser.add_argument("--run_name", type=str, default="run")
evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
evaluate_parser.add_argument("--data_path", type=str, default="data")
evaluate_parser.add_argument("--compile", **_compile_kwargs)
evaluate_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
evaluate_parser.add_argument("--tf32", **_tf32_kwargs)
evaluate_parser.add_argument(
"--dataset_wrapper",
type=str,
default=None,
help="Registered name of a dataset wrapper.",
help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
)
evaluate_parser.add_argument(
"--plugins",
nargs="*",
default=[],
help="Python packages to import for plugin registration.",
help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
)
evaluate_parser.add_argument(
"--num_workers",
type=int,
default=None,
help="Override data.workers from the YAML config.",
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
evaluate_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
default=None,
help="Directory for the dataset wrapper's disk cache.",
help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
)
evaluate_parser.add_argument(
"--profiler",
type=str,
default=None,
choices=["simple", "advanced", "pytorch"],
help="Enable Lightning profiler.",
help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
)
evaluate_parser.add_argument(
"--compute_dc_ac_metrics",
@@ -156,11 +189,34 @@ def main():
predict_parser.add_argument("--run_name", type=str, default="run")
predict_parser.add_argument("--log_dir", type=str, default="mlruns")
predict_parser.add_argument("--data_path", type=str, default="data")
predict_parser.add_argument("--dataset_wrapper", type=str, default=None)
predict_parser.add_argument("--plugins", nargs="*", default=[])
predict_parser.add_argument("--num_workers", type=int, default=None)
predict_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None)
predict_parser.add_argument(
"--dataset_wrapper",
type=str,
default=None,
help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
)
predict_parser.add_argument(
"--plugins",
nargs="*",
default=[],
help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
)
predict_parser.add_argument(
"--num_workers",
type=int,
default=None,
help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
)
predict_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
default=None,
help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
)
predict_parser.add_argument("--output_path", type=str, default="data")
predict_parser.add_argument("--compile", **_compile_kwargs)
predict_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
predict_parser.add_argument("--tf32", **_tf32_kwargs)
predict_parser.add_argument(
"--profiler",
type=str,
Expand All @@ -175,11 +231,36 @@ def main():
)
benchmark_parser.add_argument("--config", type=str, required=True)
benchmark_parser.add_argument("--data_path", type=str, default="data")
benchmark_parser.add_argument("--epochs", type=int, default=3)
benchmark_parser.add_argument("--dataset_wrapper", type=str, default=None)
benchmark_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None)
benchmark_parser.add_argument("--num_workers", type=int, default=None)
benchmark_parser.add_argument("--plugins", nargs="*", default=[])
benchmark_parser.add_argument(
"--epochs",
type=int,
default=3,
help="Number of epochs to iterate through the train dataloader.",
)
benchmark_parser.add_argument(
"--dataset_wrapper",
type=str,
default=None,
help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
)
benchmark_parser.add_argument(
"--dataset_wrapper_cache_dir",
type=str,
default=None,
help="Directory for the dataset wrapper's disk cache.",
)
benchmark_parser.add_argument(
"--num_workers",
type=int,
default=None,
help="Override data.workers from the YAML config.",
)
benchmark_parser.add_argument(
"--plugins",
nargs="*",
default=[],
help="Python packages to import for plugin registration.",
)

args = parser.parse_args()

@@ -190,4 +271,4 @@ def main():


if __name__ == "__main__":
main()
main()
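The --compile flag combines argparse's nargs="?" with const="default", so it can be passed bare or with an explicit mode. A minimal standalone sketch of that behavior (illustrative parser only, not the project's code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--compile",
    type=str,
    default=None,     # flag absent: torch.compile stays disabled
    nargs="?",
    const="default",  # bare --compile: mode "default"
    choices=[
        "default",
        "reduce-overhead",
        "max-autotune",
        "max-autotune-no-cudagraphs",
    ],
)

print(parser.parse_args([]).compile)                             # None
print(parser.parse_args(["--compile"]).compile)                  # default
print(parser.parse_args(["--compile", "max-autotune"]).compile)  # max-autotune

The shared _compile_kwargs, _bfloat16_kwargs, and _tf32_kwargs dicts are what keep this behavior identical across the train, finetune, evaluate, and predict subparsers.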
25 changes: 25 additions & 0 deletions gridfm_graphkit/cli.py
@@ -124,6 +124,9 @@ def get_training_callbacks(args):


def main_cli(args):
if getattr(args, "tf32", False):
torch.set_float32_matmul_precision("high") # enables TF32 on Ampere+ GPUs

logger = MLFlowLogger(
save_dir=args.log_dir,
experiment_name=args.exp_name,
@@ -162,6 +165,25 @@ def main_cli(args):
state_dict = torch.load(args.model_path, map_location="cpu")
model.load_state_dict(state_dict)

precision = "bf16-true" if getattr(args, "bfloat16", False) else None
if precision:
print("Using bfloat16 precision (via Lightning Trainer precision='bf16-true')")

compile_mode = getattr(args, "compile", None)
if compile_mode is not None:
if compile_mode in ("max-autotune", "max-autotune-no-cudagraphs"):
# Allow ATen GEMM as fallback so Triton configs that exceed GPU
# shared-memory limits (e.g. triton_mm OOM) are skipped gracefully
# instead of causing autotuning errors.
import torch._inductor.config as inductor_cfg

inductor_cfg.max_autotune_gemm_backends = "ATEN,TRITON"
print(f"Compiling model with torch.compile(mode='{compile_mode}')")
model.model = torch.compile(model.model, mode=compile_mode)

trainer_kwargs = {}
if precision:
trainer_kwargs["precision"] = precision
profiler = getattr(args, "profiler", None)

trainer = L.Trainer(
Expand All @@ -173,6 +195,7 @@ def main_cli(args):
default_root_dir=args.log_dir,
max_epochs=config_args.training.epochs,
callbacks=get_training_callbacks(config_args),
**trainer_kwargs,
profiler=profiler,
)
if args.command == "train" or args.command == "finetune":
@@ -186,6 +209,7 @@ def main_cli(args):
num_nodes=1,
log_every_n_steps=1,
default_root_dir=args.log_dir,
**trainer_kwargs,
profiler=profiler,
)
test_trainer.test(model=model, datamodule=litGrid)
@@ -219,6 +243,7 @@ def main_cli(args):
num_nodes=1,
log_every_n_steps=1,
default_root_dir=args.log_dir,
**trainer_kwargs,
profiler=profiler,
)
predictions = predict_trainer.predict(model=model, datamodule=litGrid)
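Taken together, the three performance flags reduce to a short setup sequence: --tf32 flips a global matmul switch, --bfloat16 becomes a Lightning precision argument, and --compile wraps the inner model. A condensed sketch of the main_cli logic above (apply_perf_flags is a hypothetical helper name; the real code inlines these steps):

import torch

def apply_perf_flags(model, tf32=False, bfloat16=False, compile_mode=None):
    if tf32:
        # TF32 matmuls on Ampere+ GPUs: faster, with slightly reduced precision
        torch.set_float32_matmul_precision("high")

    trainer_kwargs = {}
    if bfloat16:
        # Run the whole Lightning loop in bfloat16
        trainer_kwargs["precision"] = "bf16-true"

    if compile_mode is not None:
        if compile_mode in ("max-autotune", "max-autotune-no-cudagraphs"):
            import torch._inductor.config as inductor_cfg

            # Keep ATen GEMM available so Triton configs that exceed GPU
            # shared-memory limits are skipped instead of failing autotuning
            inductor_cfg.max_autotune_gemm_backends = "ATEN,TRITON"
        model = torch.compile(model, mode=compile_mode)

    return model, trainer_kwargs

trainer_kwargs is then splatted into each L.Trainer(...) construction, which is why the diff touches the train, test, and predict trainer call sites.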
13 changes: 13 additions & 0 deletions gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
@@ -369,6 +369,19 @@ def _dataloader_kwargs(self):
pin_memory=torch.cuda.is_available(),
persistent_workers=num_workers > 0,
)
# On Linux some HPC environments restrict passing open file descriptors
# via Unix socket ancillary data (SCM_RIGHTS), which causes
# "received 0 items of ancdata" with the default 'fork' start method.
# 'forkserver' avoids fd-passing by having a dedicated server process
# that re-opens shared memory objects by name instead.
if (
num_workers > 0
and torch.multiprocessing.get_start_method(allow_none=True) != "spawn"
):
import platform

if platform.system() == "Linux":
kwargs["multiprocessing_context"] = "forkserver"
return kwargs

def train_dataloader(self):
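The start-method guard above can be exercised in isolation. A minimal sketch assuming a Linux host; build_loader_kwargs and the toy TensorDataset are stand-ins for the datamodule's _dataloader_kwargs, not its actual code:

import platform

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_loader_kwargs(num_workers):
    kwargs = dict(
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
    )
    # Same guard as in _dataloader_kwargs: on Linux, prefer forkserver so
    # workers re-open shared memory objects by name rather than receiving
    # file descriptors over SCM_RIGHTS, which is what triggers
    # "received 0 items of ancdata" on restrictive HPC nodes.
    if (
        num_workers > 0
        and torch.multiprocessing.get_start_method(allow_none=True) != "spawn"
        and platform.system() == "Linux"
    ):
        kwargs["multiprocessing_context"] = "forkserver"
    return kwargs

if __name__ == "__main__":
    loader = DataLoader(
        TensorDataset(torch.arange(8.0)),
        batch_size=2,
        **build_loader_kwargs(num_workers=4),
    )
    for (batch,) in loader:
        print(batch)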
22 changes: 10 additions & 12 deletions gridfm_graphkit/models/utils.py
@@ -7,7 +7,6 @@
# Bus feature indices
PD_H,
QD_H,
QG_H,
GS,
BS,
# Output feature indices
@@ -98,10 +97,10 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict):
# Qg = Q_in + Qd - q_shunt
Qg_physics = Q_in + Qd - q_shunt

Qg_new = torch.zeros_like(bus_data_orig[:, QG_H])

# PV + REF: solve from physics
Qg_new[mask_pvref] = Qg_physics[mask_pvref]
# Use torch.where instead of boolean index-put to avoid aten.nonzero
# (data-dependent shape) which causes inductor graph breaks under
# torch.compile.
Qg_new = torch.where(mask_pvref, Qg_physics, torch.zeros_like(Qg_physics))
Pg_out = agg_bus # Active generation (Pg)
Qg_out = Qg_new # Reactive gen (Qg)
Vm_out = bus_data_pred[:, VM_OUT] # Voltage magnitude
@@ -139,17 +138,16 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict):
# ======================
# Qg (PV + REF)
# ======================
Qg_new = torch.zeros_like(bus_data_orig[:, QG_H]) # PQ buses = 0
Qg_new[mask_pvref] = Q_in[mask_pvref] + Qd[mask_pvref] - q_shunt[mask_pvref]
# Use torch.where instead of boolean index-put to avoid aten.nonzero
# (data-dependent shape) which causes inductor graph breaks under
# torch.compile.
Qg_new = torch.where(mask_pvref, Q_in + Qd - q_shunt, torch.zeros_like(Q_in))

# ======================
# Pg (REF only)
# ======================
Pg_new = torch.zeros_like(bus_data_orig[:, QG_H]) # PQ buses = 0
Pg_new[mask_pv] = agg_bus[mask_pv] # PV: keep predicted
Pg_new[mask_ref] = (
P_in[mask_ref] + Pd[mask_ref] - p_shunt[mask_ref]
) # REF: balance
Pg_ref = torch.where(mask_ref, P_in + Pd - p_shunt, torch.zeros_like(P_in))
Pg_new = torch.where(mask_pv, agg_bus, Pg_ref) # PV: keep predicted

# Voltages
Vm_out = bus_data_pred[:, VM_OUT]
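The point of this rewrite is that a masked index-put lowers to aten.nonzero, whose output shape depends on the mask's contents, so torch.compile cannot keep it in a single static graph; torch.where keeps every shape fixed. The two formulations are numerically identical, which the 'test for semantic equivalence empirically' commit checks; a standalone self-check in the same spirit (sketch, not the project's test):

import torch

torch.manual_seed(0)
n = 16
mask_pvref = torch.rand(n) < 0.5
Qg_physics = torch.randn(n)

# Old formulation: boolean index-put (data-dependent shape under compile)
qg_indexed = torch.zeros_like(Qg_physics)
qg_indexed[mask_pvref] = Qg_physics[mask_pvref]

# New formulation: fixed-shape torch.where
qg_where = torch.where(mask_pvref, Qg_physics, torch.zeros_like(Qg_physics))

assert torch.equal(qg_indexed, qg_where)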