diff --git a/.gitignore b/.gitignore
index 6457578..3859d4b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+*venv*
+data_out
 *venv
 data_out
 logs
@@ -14,4 +16,8 @@ gridfm_graphkit.egg-info
 mlruns
 *.pt
 .DS_Store
+.julia
+*logs*
+*data_out*
+site*
 .venv
diff --git a/gridfm_graphkit/__main__.py b/gridfm_graphkit/__main__.py
index c01d57a..63a72ce 100644
--- a/gridfm_graphkit/__main__.py
+++ b/gridfm_graphkit/__main__.py
@@ -11,6 +11,30 @@ def main():
     subparsers = parser.add_subparsers(dest="command", required=True)
     exp_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

+    _compile_kwargs = dict(
+        type=str,
+        default=None,
+        nargs="?",
+        const="default",
+        choices=[
+            "default",
+            "reduce-overhead",
+            "max-autotune",
+            "max-autotune-no-cudagraphs",
+        ],
+        help="Enable torch.compile with the given mode (omit value for 'default').",
+    )
+    _bfloat16_kwargs = dict(
+        action="store_true",
+        default=False,
+        help="Run in bfloat16 via the Lightning Trainer (precision='bf16-true').",
+    )
+    _tf32_kwargs = dict(
+        action="store_true",
+        default=False,
+        help="Enable TF32 on Ampere+ GPUs via torch.set_float32_matmul_precision('high').",
+    )
+
     # ---- TRAIN SUBCOMMAND ----
     train_parser = subparsers.add_parser("train", help="Run training")
     train_parser.add_argument("--config", type=str, required=True)
@@ -18,6 +42,9 @@ def main():
     train_parser.add_argument("--run_name", type=str, default="run")
     train_parser.add_argument("--log_dir", type=str, default="mlruns")
     train_parser.add_argument("--data_path", type=str, default="data")
+    train_parser.add_argument("--compile", **_compile_kwargs)
+    train_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
+    train_parser.add_argument("--tf32", **_tf32_kwargs)
     train_parser.add_argument(
         "--dataset_wrapper",
         type=str,
@@ -40,14 +67,14 @@ def main():
         "--dataset_wrapper_cache_dir",
         type=str,
         default=None,
-        help="Directory for the dataset wrapper's disk cache.",
+        help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
     )
     train_parser.add_argument(
         "--profiler",
         type=str,
         default=None,
         choices=["simple", "advanced", "pytorch"],
-        help="Enable Lightning profiler.",
+        help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
     )

     # ---- FINETUNE SUBCOMMAND ----
@@ -58,36 +85,39 @@ def main():
     finetune_parser.add_argument("--run_name", type=str, default="run")
     finetune_parser.add_argument("--log_dir", type=str, default="mlruns")
     finetune_parser.add_argument("--data_path", type=str, default="data")
+    finetune_parser.add_argument("--compile", **_compile_kwargs)
+    finetune_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
+    finetune_parser.add_argument("--tf32", **_tf32_kwargs)
     finetune_parser.add_argument(
         "--dataset_wrapper",
         type=str,
         default=None,
-        help="Registered name of a dataset wrapper.",
+        help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
     )
     finetune_parser.add_argument(
         "--plugins",
         nargs="*",
         default=[],
-        help="Python packages to import for plugin registration.",
+        help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
     )
     finetune_parser.add_argument(
         "--num_workers",
         type=int,
         default=None,
-        help="Override data.workers from the YAML config.",
+        help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
     )
     finetune_parser.add_argument(
         "--dataset_wrapper_cache_dir",
         type=str,
         default=None,
-        help="Directory for the dataset wrapper's disk cache.",
+        help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
     )
     finetune_parser.add_argument(
         "--profiler",
         type=str,
         default=None,
         choices=["simple", "advanced", "pytorch"],
-        help="Enable Lightning profiler.",
+        help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
     )

     # ---- EVALUATE SUBCOMMAND ----
@@ -107,36 +137,39 @@ def main():
     evaluate_parser.add_argument("--run_name", type=str, default="run")
     evaluate_parser.add_argument("--log_dir", type=str, default="mlruns")
     evaluate_parser.add_argument("--data_path", type=str, default="data")
+    evaluate_parser.add_argument("--compile", **_compile_kwargs)
+    evaluate_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
+    evaluate_parser.add_argument("--tf32", **_tf32_kwargs)
     evaluate_parser.add_argument(
         "--dataset_wrapper",
         type=str,
         default=None,
-        help="Registered name of a dataset wrapper.",
+        help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
     )
     evaluate_parser.add_argument(
         "--plugins",
         nargs="*",
         default=[],
-        help="Python packages to import for plugin registration.",
+        help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
     )
     evaluate_parser.add_argument(
         "--num_workers",
         type=int,
         default=None,
-        help="Override data.workers from the YAML config.",
+        help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
     )
     evaluate_parser.add_argument(
         "--dataset_wrapper_cache_dir",
         type=str,
         default=None,
-        help="Directory for the dataset wrapper's disk cache.",
+        help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
     )
     evaluate_parser.add_argument(
         "--profiler",
         type=str,
         default=None,
         choices=["simple", "advanced", "pytorch"],
-        help="Enable Lightning profiler.",
+        help="Enable Lightning profiler: 'simple', 'advanced', or 'pytorch'.",
     )
     evaluate_parser.add_argument(
         "--compute_dc_ac_metrics",
@@ -156,11 +189,34 @@ def main():
     predict_parser.add_argument("--run_name", type=str, default="run")
     predict_parser.add_argument("--log_dir", type=str, default="mlruns")
     predict_parser.add_argument("--data_path", type=str, default="data")
-    predict_parser.add_argument("--dataset_wrapper", type=str, default=None)
-    predict_parser.add_argument("--plugins", nargs="*", default=[])
-    predict_parser.add_argument("--num_workers", type=int, default=None)
-    predict_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None)
+    predict_parser.add_argument(
+        "--dataset_wrapper",
+        type=str,
+        default=None,
+        help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
+    )
+    predict_parser.add_argument(
+        "--plugins",
+        nargs="*",
+        default=[],
+        help="Python packages to import for plugin registration, e.g. gridfm_graphkit_ee",
+    )
+    predict_parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=None,
+        help="Override data.workers from the YAML config. Use 0 to debug worker crashes.",
+    )
+    predict_parser.add_argument(
+        "--dataset_wrapper_cache_dir",
+        type=str,
+        default=None,
+        help="Directory for the dataset wrapper's disk cache. If set, cache is loaded from here when present and saved here after first population.",
+    )
     predict_parser.add_argument("--output_path", type=str, default="data")
+    predict_parser.add_argument("--compile", **_compile_kwargs)
+    predict_parser.add_argument("--bfloat16", **_bfloat16_kwargs)
+    predict_parser.add_argument("--tf32", **_tf32_kwargs)
     predict_parser.add_argument(
         "--profiler",
         type=str,
@@ -175,11 +231,36 @@ def main():
     )
     benchmark_parser.add_argument("--config", type=str, required=True)
     benchmark_parser.add_argument("--data_path", type=str, default="data")
-    benchmark_parser.add_argument("--epochs", type=int, default=3)
-    benchmark_parser.add_argument("--dataset_wrapper", type=str, default=None)
-    benchmark_parser.add_argument("--dataset_wrapper_cache_dir", type=str, default=None)
-    benchmark_parser.add_argument("--num_workers", type=int, default=None)
-    benchmark_parser.add_argument("--plugins", nargs="*", default=[])
+    benchmark_parser.add_argument(
+        "--epochs",
+        type=int,
+        default=3,
+        help="Number of epochs to iterate through the train dataloader.",
+    )
+    benchmark_parser.add_argument(
+        "--dataset_wrapper",
+        type=str,
+        default=None,
+        help="Registered name of a dataset wrapper (see DATASET_WRAPPER_REGISTRY), e.g. SharedMemoryCacheDataset",
+    )
+    benchmark_parser.add_argument(
+        "--dataset_wrapper_cache_dir",
+        type=str,
+        default=None,
+        help="Directory for the dataset wrapper's disk cache.",
+    )
+    benchmark_parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=None,
+        help="Override data.workers from the YAML config.",
+    )
+    benchmark_parser.add_argument(
+        "--plugins",
+        nargs="*",
+        default=[],
+        help="Python packages to import for plugin registration.",
+    )

     args = parser.parse_args()

@@ -190,4 +271,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/gridfm_graphkit/cli.py b/gridfm_graphkit/cli.py
index 44ee1e3..015f7ed 100644
--- a/gridfm_graphkit/cli.py
+++ b/gridfm_graphkit/cli.py
@@ -124,6 +124,9 @@ def get_training_callbacks(args):


 def main_cli(args):
+    if getattr(args, "tf32", False):
+        torch.set_float32_matmul_precision("high")  # enables TF32 on Ampere+ GPUs
+
     logger = MLFlowLogger(
         save_dir=args.log_dir,
         experiment_name=args.exp_name,
@@ -162,6 +165,25 @@ def main_cli(args):
         state_dict = torch.load(args.model_path, map_location="cpu")
         model.load_state_dict(state_dict)

+    precision = "bf16-true" if getattr(args, "bfloat16", False) else None
+    if precision:
+        print("Using bfloat16 precision (via Lightning Trainer precision='bf16-true')")
+
+    compile_mode = getattr(args, "compile", None)
+    if compile_mode is not None:
+        if compile_mode in ("max-autotune", "max-autotune-no-cudagraphs"):
+            # Allow ATen GEMM as fallback so Triton configs that exceed GPU
+            # shared-memory limits (e.g. triton_mm OOM) are skipped gracefully
+            # instead of causing autotuning errors.
+            import torch._inductor.config as inductor_cfg
+
+            inductor_cfg.max_autotune_gemm_backends = "ATEN,TRITON"
+        print(f"Compiling model with torch.compile(mode='{compile_mode}')")
+        model.model = torch.compile(model.model, mode=compile_mode)
+
+    trainer_kwargs = {}
+    if precision:
+        trainer_kwargs["precision"] = precision
     profiler = getattr(args, "profiler", None)

     trainer = L.Trainer(
@@ -173,6 +195,7 @@ def main_cli(args):
         default_root_dir=args.log_dir,
         max_epochs=config_args.training.epochs,
         callbacks=get_training_callbacks(config_args),
+        **trainer_kwargs,
         profiler=profiler,
     )
     if args.command == "train" or args.command == "finetune":
@@ -186,6 +209,7 @@ def main_cli(args):
             num_nodes=1,
             log_every_n_steps=1,
             default_root_dir=args.log_dir,
+            **trainer_kwargs,
             profiler=profiler,
         )
         test_trainer.test(model=model, datamodule=litGrid)
@@ -219,6 +243,7 @@ def main_cli(args):
             num_nodes=1,
             log_every_n_steps=1,
             default_root_dir=args.log_dir,
+            **trainer_kwargs,
             profiler=profiler,
         )
         predictions = predict_trainer.predict(model=model, datamodule=litGrid)
diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
index acd45aa..54db39c 100644
--- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
+++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
@@ -369,6 +369,19 @@ def _dataloader_kwargs(self):
             pin_memory=torch.cuda.is_available(),
             persistent_workers=num_workers > 0,
         )
+        # On Linux some HPC environments restrict passing open file descriptors
+        # via Unix socket ancillary data (SCM_RIGHTS), which causes
+        # "received 0 items of ancdata" with the default 'fork' start method.
+        # 'forkserver' avoids fd-passing by having a dedicated server process
+        # that re-opens shared memory objects by name instead.
+        if (
+            num_workers > 0
+            and torch.multiprocessing.get_start_method(allow_none=True) != "spawn"
+        ):
+            import platform
+
+            if platform.system() == "Linux":
+                kwargs["multiprocessing_context"] = "forkserver"
         return kwargs

     def train_dataloader(self):
diff --git a/gridfm_graphkit/models/utils.py b/gridfm_graphkit/models/utils.py
index 46e75c7..ea4ecaf 100644
--- a/gridfm_graphkit/models/utils.py
+++ b/gridfm_graphkit/models/utils.py
@@ -7,7 +7,6 @@
     # Bus feature indices
     PD_H,
     QD_H,
-    QG_H,
     GS,
     BS,
     # Output feature indices
@@ -98,10 +97,10 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict):
         # Qg = Q_in + Qd - q_shunt
         Qg_physics = Q_in + Qd - q_shunt

-        Qg_new = torch.zeros_like(bus_data_orig[:, QG_H])
-
-        # PV + REF: solve from physics
-        Qg_new[mask_pvref] = Qg_physics[mask_pvref]
+        # Use torch.where instead of boolean index-put to avoid aten.nonzero
+        # (data-dependent shape) which causes inductor graph breaks under
+        # torch.compile.
+        Qg_new = torch.where(mask_pvref, Qg_physics, torch.zeros_like(Qg_physics))
         Pg_out = agg_bus  # Active generation (Pg)
         Qg_out = Qg_new  # Reactive gen (Qg)
         Vm_out = bus_data_pred[:, VM_OUT]  # Voltage magnitude
@@ -139,17 +138,18 @@ def forward(self, P_in, Q_in, bus_data_pred, bus_data_orig, agg_bus, mask_dict):
         # ======================
         # Qg (PV + REF)
         # ======================
-        Qg_new = torch.zeros_like(bus_data_orig[:, QG_H])  # PQ buses = 0
-        Qg_new[mask_pvref] = Q_in[mask_pvref] + Qd[mask_pvref] - q_shunt[mask_pvref]
+        # Use torch.where instead of boolean index-put to avoid aten.nonzero
+        # (data-dependent shape) which causes inductor graph breaks under
+        # torch.compile.
+        Qg_new = torch.where(mask_pvref, Q_in + Qd - q_shunt, torch.zeros_like(Q_in))

         # ======================
         # Pg (REF only)
         # ======================
-        Pg_new = torch.zeros_like(bus_data_orig[:, QG_H])  # PQ buses = 0
-        Pg_new[mask_pv] = agg_bus[mask_pv]  # PV: keep predicted
-        Pg_new[mask_ref] = (
-            P_in[mask_ref] + Pd[mask_ref] - p_shunt[mask_ref]
-        )  # REF: balance
+        # REF is applied last so it overrides PV where the masks overlap,
+        # exactly like the sequential masked assignments this replaces.
+        Pg_pv = torch.where(mask_pv, agg_bus, torch.zeros_like(P_in))  # PV: keep predicted
+        Pg_new = torch.where(mask_ref, P_in + Pd - p_shunt, Pg_pv)  # REF: balance

         # Voltages
         Vm_out = bus_data_pred[:, VM_OUT]
diff --git a/gridfm_graphkit/tasks/base_task.py b/gridfm_graphkit/tasks/base_task.py
index ec75ccc..90c8f7b 100644
--- a/gridfm_graphkit/tasks/base_task.py
+++ b/gridfm_graphkit/tasks/base_task.py
@@ -1,4 +1,5 @@
 import os
+import time
 from abc import ABC, abstractmethod
 import lightning as L
 from pytorch_lightning.utilities import rank_zero_only
@@ -19,11 +20,37 @@ def __init__(self, args, data_normalizers):
         self.data_normalizers = data_normalizers
         self.save_hyperparameters()

+    def on_after_batch_transfer(self, batch, dataloader_idx: int):
+        """Cast float tensors in HeteroData batches to the model's parameter dtype.
+
+        Lightning's automatic mixed-precision casting does not handle PyG
+        HeteroData objects, so we do it manually here to avoid dtype mismatches
+        when --bfloat16 (precision='bf16-true') is used.
+        """
+        if not hasattr(self, "model"):
+            return batch
+        try:
+            target_dtype = next(self.model.parameters()).dtype
+        except StopIteration:
+            return batch
+        if target_dtype == torch.float32:
+            # No casting needed for the default precision.
+            return batch
+        # Walk node/edge stores of a PyG batch; batches without `.stores` pass through.
+        for store in getattr(batch, "stores", ()):
+            for key, val in store.items():
+                if isinstance(val, torch.Tensor) and val.is_floating_point():
+                    store[key] = val.to(target_dtype)
+        return batch
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         """Forward pass"""
         pass

+    def on_train_batch_start(self, batch, batch_idx):
+        self._batch_start_time = time.perf_counter()
+
     @abstractmethod
     def training_step(self, batch):
         pass
diff --git a/tests/util/test_util_implementation_equivalence.py b/tests/util/test_util_implementation_equivalence.py
new file mode 100644
index 0000000..3a9b68a
--- /dev/null
+++ b/tests/util/test_util_implementation_equivalence.py
@@ -0,0 +1,47 @@
+import torch
+import pytest
+
+
+def impl_a(P_in, Pd, p_shunt, agg_bus, mask_pv, mask_ref):
+    """Reference: sequential masked assignment (REF is written last)."""
+    Pg_new = torch.zeros_like(P_in)
+    Pg_new[mask_pv] = agg_bus[mask_pv]
+    Pg_new[mask_ref] = P_in[mask_ref] + Pd[mask_ref] - p_shunt[mask_ref]
+    return Pg_new
+
+
+def impl_b(P_in, Pd, p_shunt, agg_bus, mask_pv, mask_ref):
+    """torch.where form used in the model (REF is applied last as well)."""
+    Pg_pv = torch.where(mask_pv, agg_bus, torch.zeros_like(P_in))
+    Pg_new = torch.where(mask_ref, P_in + Pd - p_shunt, Pg_pv)
+    return Pg_new
+
+
+@pytest.mark.parametrize("allow_overlap", [False, True])
+def test_run(allow_overlap):
+    n = 10000
+    device = "cpu"
+    torch.manual_seed(0)
+
+    P_in = torch.randn(n, device=device)
+    Pd = torch.randn(n, device=device)
+    p_shunt = torch.randn(n, device=device)
+    agg_bus = torch.randn(n, device=device)
+
+    # random masks
+    mask_pv = torch.rand(n, device=device) > 0.7
+    mask_ref = torch.rand(n, device=device) > 0.7
+
+    if not allow_overlap:
+        mask_ref = mask_ref & (~mask_pv)  # enforce disjointness
+    else:
+        # Guard against a vacuous check: the masks must really overlap.
+        assert (mask_pv & mask_ref).any()
+
+    out_a = impl_a(P_in, Pd, p_shunt, agg_bus, mask_pv, mask_ref)
+    out_b = impl_b(P_in, Pd, p_shunt, agg_bus, mask_pv, mask_ref)
+
+    max_diff = (out_a - out_b).abs().max().item()
+    assert torch.allclose(out_a, out_b), (
+        f"Outputs differ! Max abs diff: {max_diff:.6f}"
+    )