diff --git a/.gitignore b/.gitignore
index c2ade997..b7920944 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 # Docs
 .cache/
 docs/reference/*
+examples/MP/experiments
+examples/QM9
 *doctrees*
 /site
 
diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml
index c30fa1ec..3a296a32 100644
--- a/src/electrai/configs/MP/config.yaml
+++ b/src/electrai/configs/MP/config.yaml
@@ -40,3 +40,6 @@ wb_pname: mp-experiment
 
 # checkpoints
 ckpt_path: ./checkpoints
+save_pred: true
+log_dir: ./logs
+out_dir: ./predictions
diff --git a/src/electrai/dataloader/dataset.py b/src/electrai/dataloader/dataset.py
index b3321143..1010cf0c 100644
--- a/src/electrai/dataloader/dataset.py
+++ b/src/electrai/dataloader/dataset.py
@@ -62,6 +62,22 @@ def setup(self, stage=None):
         elif stage == "test":
             self.test_set = self.subsets["test"]
 
+    def setup(self, stage=None):
+        dataset = RhoData(
+            self.root, precision=self.precision, augmentation=self.augmentation
+        )
+        self.subsets = split_data(
+            dataset,
+            val_frac=self.val_frac,
+            split_file=self.split_file,
+            random_seed=self.random_seed,
+        )
+        if stage == "fit":
+            self.train_set = self.subsets["train"]
+            self.val_set = self.subsets["validation"]
+        elif stage == "test":
+            self.test_set = self.subsets["test"]
+
     def train_dataloader(self):
         return DataLoader(
             self.train_set,
@@ -69,8 +85,6 @@ def train_dataloader(self):
             num_workers=self.train_workers,
             shuffle=True,
             collate_fn=collate_fn,
-            pin_memory=self.pin_memory,
-            drop_last=self.drop_last,
         )
 
     def val_dataloader(self):
@@ -79,9 +93,6 @@ def val_dataloader(self):
             self.batch_size,
             num_workers=self.val_workers,
             shuffle=False,
-            collate_fn=collate_fn,
-            pin_memory=self.pin_memory,
-            drop_last=self.drop_last,
         )
 
     def test_dataloader(self):
@@ -90,10 +101,11 @@ def test_dataloader(self):
             batch_size=1,
             num_workers=self.val_workers,
             collate_fn=collate_fn,
-            pin_memory=self.pin_memory,
-            drop_last=self.drop_last,
         )
 
+    def on_exception(self, exception: BaseException) -> None:
+        return
+
 
 class RhoData(Dataset):
     def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs):
diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py
index a98ea3a1..69afefd7 100644
--- a/src/electrai/entrypoints/main.py
+++ b/src/electrai/entrypoints/main.py
@@ -2,8 +2,12 @@
 
 import argparse
 
+import torch
+from src.electrai.entrypoints.test import test
 from src.electrai.entrypoints.train import train
 
+torch.backends.cudnn.conv.fp32_precision = "tf32"
+
 
 def main() -> None:
     """Entry point.
@@ -24,10 +28,15 @@ def main() -> None:
     train_parser = subparsers.add_parser("train", help="Train the model")
     train_parser.add_argument("--config", type=str, required=True)
 
+    test_parser = subparsers.add_parser("test", help="Evaluate the model")
+    test_parser.add_argument("--config", type=str, required=True)
+
     args = parser.parse_args()
 
     if args.command == "train":
         train(args)
+    elif args.command == "test":
+        test(args)
     else:
         raise ValueError(f"Unknown command: {args.command}")
 
diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
new file mode 100644
index 00000000..acc7eb2c
--- /dev/null
+++ b/src/electrai/entrypoints/test.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import torch
+import yaml
+from hydra.utils import instantiate
+from lightning.pytorch import Trainer
+from src.electrai.lightning import LightningGenerator
+
+
+def test(args):
+    # -----------------------------
+    # Load YAML config
+    # -----------------------------
+    config_path = Path(args.config)
+    with Path.open(config_path) as f:
+        cfg_dict = yaml.safe_load(f)
+    cfg = SimpleNamespace(**cfg_dict)
+
+    # -----------------------------
+    # Data
+    # -----------------------------
+    datamodule = instantiate(cfg.data)
+
+    # -----------------------------
+    # Model (LightningModule handles architecture + loss + optimizer)
+    # -----------------------------
+    lit_model = LightningGenerator(cfg)
+    lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)
+
+    # -----------------------------
+    # Checkpoint
+    # -----------------------------
+    ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
+
+    # -----------------------------
+    # Trainer
+    # -----------------------------
+    if cfg.save_pred:
+        out_dir = Path(getattr(cfg, "out_dir", "predictions"))
+        out_dir.mkdir(exist_ok=True, parents=True)
+    else:
+        out_dir = None
+    log_dir = Path(getattr(cfg, "log_dir", "logs"))
+    tmp_dir = log_dir / "tmp"
+    for directory in [log_dir, tmp_dir]:
+        directory.mkdir(exist_ok=True, parents=True)
+    trainer = Trainer(
+        logger=None,
+        callbacks=None,
+        accelerator="gpu" if torch.cuda.is_available() else "cpu",
+        devices=1,
+        precision=cfg.model_precision,
+    )
+
+    lit_model.test_cfg = SimpleNamespace(
+        log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred
+    )
+
+    # -----------------------------
+    # Test
+    # -----------------------------
+    ckpt = ckpt_path / "last.ckpt"
+    if not ckpt.exists():
+        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
+
+    trainer.test(model=lit_model, datamodule=datamodule, ckpt_path=ckpt)
diff --git a/src/electrai/entrypoints/train.py b/src/electrai/entrypoints/train.py
index 0dd54365..9bab7572 100644
--- a/src/electrai/entrypoints/train.py
+++ b/src/electrai/entrypoints/train.py
@@ -11,8 +11,6 @@
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from src.electrai.lightning import LightningGenerator
 
-torch.backends.cudnn.conv.fp32_precision = "tf32"
-
 
 def train(args):
     # -----------------------------
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index 7a0c904c..956430f2 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -1,6 +1,11 @@
 from __future__ import annotations
 
+import shutil
+import time
+
+import numpy as np
 import torch
+import torch.distributed as dist
 from lightning.pytorch import LightningModule
 from src.electrai.model.loss.charge import NormMAE
 from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet
@@ -79,3 +84,103 @@ def configure_optimizers(self):
             optimizer, [linsch, cossch], milestones=[self.cfg.warmup_length]
         )
         return [optimizer], [scheduler]
+
+    def on_test_start(self):
+        self.log_dir = self.test_cfg.log_dir
+        self.out_dir = self.test_cfg.out_dir
+        self.tmp_dir = self.test_cfg.tmp_dir
+        self.save_pred = self.test_cfg.save_pred
+        self.test_outputs = []
+
+    def test_step(self, batch, batch_idx):
+        x = batch["data"]
+        y = batch["label"]
+        indices = batch["index"]
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+
+        start.record()
+        preds = self(x)
+        loss = self.loss_fn(preds, y)
+        end.record()
+
+        torch.cuda.synchronize()
+        elapsed = start.elapsed_time(end)
+
+        self.log("test_loss", loss, prog_bar=True, sync_dist=True)
+
+        out = {
+            "target": y.detach().cpu(),
+            "index": indices,
+            "nmae": loss.detach().cpu(),
+            "duration": elapsed,
+        }
+        if self.save_pred:
+            out["pred"] = preds.detach().cpu()
+        return out
+
+    def on_test_batch_end(self, outputs, batch, batch_idx):
+        indices = outputs["index"]
+        nmae = outputs["nmae"]
+
+        if self.save_pred:
+            preds = outputs["pred"]
+            for i in range(len(indices)):
+                idx = indices[i]
+                np.save(
+                    self.out_dir / f"rank_{self.global_rank}_{idx}.npy",
+                    preds[i].squeeze(0).cpu().numpy(),
+                )
+
+        if isinstance(nmae, torch.Tensor) and nmae.ndim == 0:
+            nmae = nmae.unsqueeze(0)
+        tmp_csv = (
+            self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv"
+        )
+        with open(tmp_csv, "w") as f:
+            for idx, n in zip(indices, nmae, strict=True):
+                f.write(f"rank_{self.global_rank},{idx},{n.item()}\n")
+
+    def on_test_epoch_end(self):
+        is_dist = dist.is_available() and dist.is_initialized()
+        rank = dist.get_rank() if is_dist else 0
+
+        # Count only files written by THIS rank
+        local_count = len(list(self.tmp_dir.glob(f"metrics_rank_{rank}_batch_*.csv")))
+
+        if is_dist:
+            count_tensor = torch.tensor(
+                [local_count], dtype=torch.long, device=self.device
+            )
+            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
+            expected_total = int(count_tensor.item())
+            dist.barrier()
+        else:
+            expected_total = local_count
+
+        final_csv = self.log_dir / "metrics.csv"
+
+        if self.global_rank == 0:
+            retries = 0
+            all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
+            while len(all_tmp_csvs) < expected_total and retries < 60:
+                time.sleep(1)
+                all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
+                retries += 1
+
+            if len(all_tmp_csvs) < expected_total:
+                raise RuntimeError(
+                    f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}."
+                )
+
+            with open(final_csv, "w") as f_out:
+                f_out.write("rank,index,nmae\n")
+                for tmp_csv in all_tmp_csvs:
+                    with open(tmp_csv) as f_in:
+                        for line in f_in:
+                            f_out.write(line)
+
+            shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
+        if is_dist:
+            dist.barrier()
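
Usage sketch (not part of the patch): with the `test` subcommand registered in `main()` above, and assuming the entrypoint is invoked by module path (no console-script name is visible in this diff), evaluation runs as:

    python -m src.electrai.entrypoints.main test --config src/electrai/configs/MP/config.yaml

The run loads `<ckpt_path>/last.ckpt` and fails fast if it is missing, writes per-rank `.npy` predictions to `out_dir` when `save_pred: true`, and merges the per-batch CSVs from `<log_dir>/tmp` into `<log_dir>/metrics.csv` on rank 0.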
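The merged file carries the `rank,index,nmae` header written in `on_test_epoch_end`, one row per test sample, so downstream analysis needs no knowledge of the per-rank temporary files. A minimal summary sketch, assuming pandas is available and the `log_dir: ./logs` default from the config above:

    import pandas as pd

    # Per-sample normalized MAE rows, as merged by on_test_epoch_end on rank 0.
    df = pd.read_csv("logs/metrics.csv")
    print(f"NMAE over {len(df)} samples: "
          f"mean={df['nmae'].mean():.4e}, std={df['nmae'].std():.4e}")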