From c6cdd0d320a8a9760e1664f3f2d28c2eb0720309 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Sat, 10 Jan 2026 20:15:18 -0500
Subject: [PATCH 1/9] test features

---
 src/electrai/entrypoints/main.py |  6 +++
 src/electrai/entrypoints/test.py | 60 ++++++++++++++++++++++++++++
 src/electrai/lightning.py        | 68 ++++++++++++++++++++++++++++++++
 3 files changed, 134 insertions(+)
 create mode 100644 src/electrai/entrypoints/test.py

diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py
index a98ea3a1..4293e1f5 100644
--- a/src/electrai/entrypoints/main.py
+++ b/src/electrai/entrypoints/main.py
@@ -2,6 +2,7 @@
 
 import argparse
 
+from src.electrai.entrypoints.test import test
 from src.electrai.entrypoints.train import train
 
 
@@ -24,10 +25,15 @@ def main() -> None:
     train_parser = subparsers.add_parser("train", help="Train the model")
     train_parser.add_argument("--config", type=str, required=True)
 
+    test_parser = subparsers.add_parser("test", help="Test the model")
+    test_parser.add_argument("--config", type=str, required=True)
+
     args = parser.parse_args()
 
     if args.command == "train":
         train(args)
+    if args.command == "test":
+        test(args)
     else:
         raise ValueError(f"Unknown command: {args.command}")
 
diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
new file mode 100644
index 00000000..145497a3
--- /dev/null
+++ b/src/electrai/entrypoints/test.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from pathlib import Path
+from types import SimpleNamespace
+
+import torch
+import yaml
+from hydra.utils import instantiate
+from lightning.pytorch import Trainer
+from src.electrai.lightning import LightningGenerator
+
+torch.backends.cudnn.conv.fp32_precision = "tf32"
+
+
+def test(args):
+    # -----------------------------
+    # Load YAML config
+    # -----------------------------
+    config_path = Path(args.config)
+    with Path.open(config_path) as f:
+        cfg_dict = yaml.safe_load(f)
+    cfg = SimpleNamespace(**cfg_dict)
+
+    # -----------------------------
+    # Data
+    # -----------------------------
+    datamodule = instantiate(cfg.data)
+    test_loader = datamodule.test_dataloader()
+
+    # -----------------------------
+    # Model (LightningModule handles architecture + loss + optimizer)
+    # -----------------------------
+    lit_model = LightningGenerator(cfg)
+
+    # -----------------------------
+    # Callback
+    # -----------------------------
+    ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
+
+    # -----------------------------
+    # Trainer
+    # -----------------------------
+    trainer = Trainer(
+        logger=None,
+        callbacks=None,
+        accelerator="gpu" if torch.cuda.is_available() else "cpu",
+        devices=1,
+        precision=cfg.model_precision,
+    )
+    lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)
+
+    # -----------------------------
+    # Train
+    # -----------------------------
+    ckpt = ckpt_path / "last.ckpt"
+    trainer.test(
+        model=lit_model,
+        dataloaders=test_loader,
+        ckpt_path=ckpt if ckpt.exists() else None,
+    )
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index 7a0c904c..9ba7033d 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -1,5 +1,9 @@
 from __future__ import annotations
 
+import time
+from pathlib import Path
+
+import numpy as np
 import torch
 from lightning.pytorch import LightningModule
 from src.electrai.model.loss.charge import NormMAE
@@ -79,3 +83,67 @@ def configure_optimizers(self):
             optimizer, [linsch, cossch], milestones=[self.cfg.warmup_length]
         )
         return [optimizer], [scheduler]
+
+    def on_test_start(self):
+        self.log_dir = self.test_cfg.log_dir
+        self.out_dir = self.test_cfg.out_dir
+        self.test_outputs = []
+
+    def test_step(self, batch, batch_idx):
+        start_time = time.time()
+        x = batch["data"]
+        y = batch["label"]
+        indices = batch["index"]
+
+        preds = self(x)
+        loss = self.loss_fn(preds, y)
+
+        self.log("test_loss", loss, prog_bar=True, sync_dist=True)
+
+        return {
+            "pred": preds.detach().cpu(),
+            "target": y.detach().cpu(),
+            "index": indices,
+            "nmae": loss.detach().cpu(),
+            "time": time.time() - start_time,  # + batch["load_time"][0], ???
+        }
+
+    def on_test_batch_end(self, outputs, batch, batch_idx):
+        if self.out_dir is not None:
+            out_dir = Path(self.out_dir)
+            out_dir.mkdir(exist_ok=True, parents=True)
+
+            preds = outputs["pred"]
+            indices = outputs["index"]
+
+            for i in range(len(indices)):
+                idx = indices[i]
+                pred_i = preds[i].numpy()
+                np.save(out_dir / f"{idx}.npy", pred_i)
+
+        self.test_outputs.append(outputs)
+
+    def on_test_epoch_end(self):
+        index = []
+        nmae_all = []
+
+        for o in self.test_outputs:
+            index.extend(list(o["index"]))
+
+            n = o["nmae"]
+            if n.ndim == 0:
+                nmae_all.append(n.unsqueeze(0))
+            else:
+                nmae_all.append(n)
+
+        nmae = torch.cat(nmae_all, dim=0)
+
+        if self.log_dir is not None:
+            log_dir = Path(self.log_dir)
+            log_dir.mkdir(exist_ok=True, parents=True)
+            csv_path = Path(self.log_dir) / "metrics.csv"
+
+            with open(csv_path, "w") as f:
+                f.write("index,nmae\n")
+                for ind, err in zip(index, nmae.tolist(), strict=False):
+                    f.write(f"{ind},{err}\n")
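Note on the config this first patch assumes: test() above only touches a handful of keys — cfg.data (anything hydra.utils.instantiate can resolve), cfg.model_precision, and the ckpt_path, log_dir, and out_dir paths. A minimal sketch of a compatible config, with placeholder values; the _target_ path assumes the RhoRead datamodule that appears later in this series:

    import yaml
    from types import SimpleNamespace

    yaml_text = """
    data:
      _target_: src.electrai.dataloader.dataset.RhoRead  # resolved by hydra.utils.instantiate
      root: ./data
    model_precision: 32
    ckpt_path: ./checkpoints
    log_dir: ./logs
    out_dir: ./predictions
    """
    cfg = SimpleNamespace(**yaml.safe_load(yaml_text))
    assert cfg.model_precision == 32
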
From 8942188834c1bfbe33453ae111f83f0b09ac5e05 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Mon, 26 Jan 2026 10:44:59 -0500
Subject: [PATCH 2/9] Updated argparse help message and if condition

---
 src/electrai/entrypoints/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py
index 4293e1f5..a734b7a8 100644
--- a/src/electrai/entrypoints/main.py
+++ b/src/electrai/entrypoints/main.py
@@ -25,14 +25,14 @@ def main() -> None:
     train_parser = subparsers.add_parser("train", help="Train the model")
     train_parser.add_argument("--config", type=str, required=True)
 
-    test_parser = subparsers.add_parser("test", help="Test the model")
+    test_parser = subparsers.add_parser("test", help="Evaluate the model")
     test_parser.add_argument("--config", type=str, required=True)
 
     args = parser.parse_args()
 
     if args.command == "train":
         train(args)
-    if args.command == "test":
+    elif args.command == "test":
         test(args)
     else:
         raise ValueError(f"Unknown command: {args.command}")

From ad01b7bed516f32850cb0443e9d3ca717b9c7382 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Mon, 26 Jan 2026 10:48:33 -0500
Subject: [PATCH 3/9] raise error if checkpoint not found for model evaluation

---
 src/electrai/entrypoints/test.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
index 145497a3..2f097caf 100644
--- a/src/electrai/entrypoints/test.py
+++ b/src/electrai/entrypoints/test.py
@@ -53,8 +53,7 @@ def test(args):
     # Train
     # -----------------------------
     ckpt = ckpt_path / "last.ckpt"
-    trainer.test(
-        model=lit_model,
-        dataloaders=test_loader,
-        ckpt_path=ckpt if ckpt.exists() else None,
-    )
+    if not ckpt.exists():
+        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
+
+    trainer.test(model=lit_model, dataloaders=test_loader, ckpt_path=ckpt)
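A side note on the dispatch fixed in PATCH 2: argparse can also enforce the subcommand itself, which removes the if/elif chain as a failure point entirely. A sketch of that alternative (standard-library pattern, not the project's code; the print handlers are stand-ins for train/test):

    import argparse

    parser = argparse.ArgumentParser()
    # required=True makes argparse reject a missing subcommand up front
    subparsers = parser.add_subparsers(dest="command", required=True)
    for name, help_msg in (("train", "Train the model"), ("test", "Evaluate the model")):
        sub = subparsers.add_parser(name, help=help_msg)
        sub.add_argument("--config", type=str, required=True)

    args = parser.parse_args(["test", "--config", "config.yaml"])
    handlers = {"train": print, "test": print}  # hypothetical stand-ins
    handlers[args.command](args)

With this shape, the ValueError branch becomes unreachable rather than load-bearing.
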
From a4ed3bf0e5db336f88b9d2d8457cf53de97e825b Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Mon, 26 Jan 2026 10:51:20 -0500
Subject: [PATCH 4/9] moved test configuration

---
 src/electrai/entrypoints/test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
index 2f097caf..3d9584d4 100644
--- a/src/electrai/entrypoints/test.py
+++ b/src/electrai/entrypoints/test.py
@@ -31,6 +31,7 @@ def test(args):
     # Model (LightningModule handles architecture + loss + optimizer)
     # -----------------------------
     lit_model = LightningGenerator(cfg)
+    lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)
 
     # -----------------------------
     # Callback
@@ -47,7 +48,6 @@ def test(args):
         devices=1,
         precision=cfg.model_precision,
     )
-    lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)
 
     # -----------------------------
     # Train

From 9be25c6cff0b51e809c37c6cf8b9b4245bd9c170 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Mon, 26 Jan 2026 11:09:19 -0500
Subject: [PATCH 5/9] changed variable names for clarity

---
 src/electrai/lightning.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index 9ba7033d..1e674074 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -124,26 +124,26 @@ def on_test_batch_end(self, outputs, batch, batch_idx):
         self.test_outputs.append(outputs)
 
     def on_test_epoch_end(self):
-        index = []
-        nmae_all = []
+        indices = []
+        nmae_chunks = []
 
-        for o in self.test_outputs:
-            index.extend(list(o["index"]))
+        for output in self.test_outputs:
+            indices.extend(list(output["index"]))
 
-            n = o["nmae"]
-            if n.ndim == 0:
-                nmae_all.append(n.unsqueeze(0))
+            batch_nmae = output["nmae"]
+            if batch_nmae.ndim == 0:
+                nmae_chunks.append(batch_nmae.unsqueeze(0))
             else:
-                nmae_all.append(n)
+                nmae_chunks.append(batch_nmae)
 
-        nmae = torch.cat(nmae_all, dim=0)
+        all_nmae = torch.cat(nmae_chunks, dim=0)
 
         if self.log_dir is not None:
             log_dir = Path(self.log_dir)
             log_dir.mkdir(exist_ok=True, parents=True)
-            csv_path = Path(self.log_dir) / "metrics.csv"
+            csv_path = log_dir / "metrics.csv"
 
             with open(csv_path, "w") as f:
                 f.write("index,nmae\n")
-                for ind, err in zip(index, nmae.tolist(), strict=False):
+                for ind, err in zip(indices, all_nmae.tolist(), strict=False):
                     f.write(f"{ind},{err}\n")

From 1991c724ff9914e89bef5f262165ef3a724a541b Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Wed, 4 Feb 2026 17:50:17 -0500
Subject: [PATCH 6/9] moved the precision mode to the main script to avoid
 code duplication in train and test modules

---
 src/electrai/entrypoints/main.py  | 3 +++
 src/electrai/entrypoints/test.py  | 2 --
 src/electrai/entrypoints/train.py | 2 --
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py
index a734b7a8..69afefd7 100644
--- a/src/electrai/entrypoints/main.py
+++ b/src/electrai/entrypoints/main.py
@@ -2,9 +2,12 @@
 
 import argparse
 
+import torch
 from src.electrai.entrypoints.test import test
 from src.electrai.entrypoints.train import train
 
+torch.backends.cudnn.conv.fp32_precision = "tf32"
+
 
 def main() -> None:
     """Entry point.
diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
index 3d9584d4..415600d4 100644
--- a/src/electrai/entrypoints/test.py
+++ b/src/electrai/entrypoints/test.py
@@ -9,8 +9,6 @@
 from lightning.pytorch import Trainer
 from src.electrai.lightning import LightningGenerator
 
-torch.backends.cudnn.conv.fp32_precision = "tf32"
-
 
 def test(args):
     # -----------------------------
diff --git a/src/electrai/entrypoints/train.py b/src/electrai/entrypoints/train.py
index 188fae23..f4201d9b 100644
--- a/src/electrai/entrypoints/train.py
+++ b/src/electrai/entrypoints/train.py
@@ -11,8 +11,6 @@
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from src.electrai.lightning import LightningGenerator
 
-torch.backends.cudnn.conv.fp32_precision = "tf32"
-
 
 def train(args):
     # -----------------------------
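On the setting hoisted in this patch: torch.backends.cudnn.conv.fp32_precision = "tf32" is the newer per-backend precision control. My understanding (an assumption about older releases, not something this series states) is that PyTorch builds predating it expose the same trade-off through boolean TF32 flags:

    import torch

    # Newer API, as used in main.py above:
    # torch.backends.cudnn.conv.fp32_precision = "tf32"

    # Legacy equivalents on releases without fp32_precision:
    torch.backends.cudnn.allow_tf32 = True        # convolutions
    torch.backends.cuda.matmul.allow_tf32 = True  # float32 matmuls

Either way the intent is the same: trade a little float32 mantissa for markedly faster conv/matmul kernels on Ampere-class GPUs and newer.
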
From cad9722bb9961d6dc1ffc317680be64d5eebd193 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Fri, 6 Feb 2026 22:58:31 -0500
Subject: [PATCH 7/9] saving output from different ranks (#62)

Handles logging/saving the performance metric across multiple ranks.

---------

Co-authored-by: Hananeh Oliaei
Co-authored-by: Betsy Cannon
---
 .gitignore                          |   2 +
 src/electrai/configs/MP/config.yaml |   3 +
 src/electrai/entrypoints/test.py    |  13 ++++
 src/electrai/lightning.py           | 103 ++++++++++++++++++----------
 4 files changed, 85 insertions(+), 36 deletions(-)

diff --git a/.gitignore b/.gitignore
index c2ade997..b7920944 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
 # Docs
 .cache/
 docs/reference/*
+./examples/MP/experiments
+./examples/QM9
 *doctrees*
 /site
 
diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml
index b0a93ce7..3da0a1a0 100644
--- a/src/electrai/configs/MP/config.yaml
+++ b/src/electrai/configs/MP/config.yaml
@@ -39,3 +39,6 @@ wb_pname: mp-experiment
 
 # checkpoints
 ckpt_path: ./checkpoints
+save_pred: true
+log_dir: ./logs
+out_dir: ./predictions
diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
index 415600d4..1fe69a92 100644
--- a/src/electrai/entrypoints/test.py
+++ b/src/electrai/entrypoints/test.py
@@ -39,6 +39,15 @@ def test(args):
     # -----------------------------
     # Trainer
     # -----------------------------
+    if cfg.save_pred:
+        out_dir = Path(getattr(cfg, "out_dir", "predictions"))
+        out_dir.mkdir(exist_ok=True, parents=True)
+    else:
+        out_dir = None
+    log_dir = Path(getattr(cfg, "log_dir", "logs"))
+    tmp_dir = log_dir / "tmp"
+    for directory in [log_dir, tmp_dir]:
+        directory.mkdir(exist_ok=True, parents=True)
     trainer = Trainer(
         logger=None,
         callbacks=None,
@@ -47,6 +56,10 @@ def test(args):
         precision=cfg.model_precision,
     )
 
+    lit_model.test_cfg = SimpleNamespace(
+        log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred
+    )
+
     # -----------------------------
     # Train
     # -----------------------------
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index 1e674074..3a384cae 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
+import shutil
 import time
-from pathlib import Path
 
 import numpy as np
 import torch
+import torch.distributed as dist
 from lightning.pytorch import LightningModule
 from src.electrai.model.loss.charge import NormMAE
 from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet
@@ -87,6 +88,8 @@ def configure_optimizers(self):
     def on_test_start(self):
         self.log_dir = self.test_cfg.log_dir
         self.out_dir = self.test_cfg.out_dir
+        self.tmp_dir = self.test_cfg.tmp_dir
+        self.save_pred = self.test_cfg.save_pred
         self.test_outputs = []
 
     def test_step(self, batch, batch_idx):
@@ -100,50 +103,78 @@ def test_step(self, batch, batch_idx):
 
         self.log("test_loss", loss, prog_bar=True, sync_dist=True)
 
-        return {
-            "pred": preds.detach().cpu(),
+        out = {
             "target": y.detach().cpu(),
             "index": indices,
             "nmae": loss.detach().cpu(),
-            "time": time.time() - start_time,  # + batch["load_time"][0], ???
+            "time": time.time() - start_time,
         }
+        if self.save_pred:
+            out["pred"] = preds.detach().cpu()
+        return out
 
     def on_test_batch_end(self, outputs, batch, batch_idx):
-        if self.out_dir is not None:
-            out_dir = Path(self.out_dir)
-            out_dir.mkdir(exist_ok=True, parents=True)
+        indices = outputs["index"]
+        nmae = outputs["nmae"]
 
+        if self.save_pred:
             preds = outputs["pred"]
-            indices = outputs["index"]
-
             for i in range(len(indices)):
                 idx = indices[i]
-                pred_i = preds[i].numpy()
-                np.save(out_dir / f"{idx}.npy", pred_i)
-
-        self.test_outputs.append(outputs)
+                np.save(
+                    self.out_dir / f"rank_{self.global_rank}_{idx}.npy",
+                    preds[i].squeeze(0).cpu().numpy(),
+                )
+
+        if isinstance(nmae, torch.Tensor) and nmae.ndim == 0:
+            nmae = nmae.unsqueeze(0)
+        tmp_csv = (
+            self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv"
+        )
+        with open(tmp_csv, "w") as f:
+            for idx, n in zip(indices, nmae, strict=True):
+                f.write(f"rank_{self.global_rank},{idx},{n.item()}\n")
 
     def on_test_epoch_end(self):
-        indices = []
-        nmae_chunks = []
-
-        for output in self.test_outputs:
-            indices.extend(list(output["index"]))
-
-            batch_nmae = output["nmae"]
-            if batch_nmae.ndim == 0:
-                nmae_chunks.append(batch_nmae.unsqueeze(0))
-            else:
-                nmae_chunks.append(batch_nmae)
-
-        all_nmae = torch.cat(nmae_chunks, dim=0)
-
-        if self.log_dir is not None:
-            log_dir = Path(self.log_dir)
-            log_dir.mkdir(exist_ok=True, parents=True)
-            csv_path = log_dir / "metrics.csv"
-
-            with open(csv_path, "w") as f:
-                f.write("index,nmae\n")
-                for ind, err in zip(indices, all_nmae.tolist(), strict=False):
-                    f.write(f"{ind},{err}\n")
+        is_dist = dist.is_available() and dist.is_initialized()
+        rank = dist.get_rank() if is_dist else 0
+
+        # Count only files written by THIS rank
+        local_count = len(list(self.tmp_dir.glob(f"metrics_rank_{rank}_batch_*.csv")))
+
+        if is_dist:
+            count_tensor = torch.tensor(
+                [local_count], dtype=torch.long, device=self.device
+            )
+            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
+            expected_total = int(count_tensor.item())
+            dist.barrier()
+        else:
+            expected_total = local_count
+
+        final_csv = self.log_dir / "metrics.csv"
+
+        if self.global_rank == 0:
+            retries = 0
+            all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
+            while len(all_tmp_csvs) < expected_total and retries < 60:
+                time.sleep(1)
+                all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
+                retries += 1
+
+            if len(all_tmp_csvs) < expected_total:
+                raise RuntimeError(
+                    f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}."
+                )
+
+            with open(final_csv, "w") as f_out:
+                f_out.write("rank,index,nmae\n")
+                for tmp_csv in all_tmp_csvs:
+                    with open(tmp_csv) as f_in:
+                        for line in f_in:
+                            f_out.write(line)
+
+            shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
+        if is_dist:
+            dist.barrier()
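The core of PATCH 7 is a shared-filesystem rendezvous: every rank writes its own metric shards, an all_reduce agrees on the total shard count, and rank 0 waits for that many files before concatenating them. Stripped of the Lightning plumbing, the pattern looks like this (a sketch under the same shared-filesystem assumption; merge_rank_csvs is a hypothetical helper, not part of the codebase):

    from pathlib import Path

    import torch.distributed as dist

    def merge_rank_csvs(tmp_dir: Path, final_csv: Path, header: str) -> None:
        """Concatenate per-rank shard CSVs on rank 0; other ranks wait."""
        is_dist = dist.is_available() and dist.is_initialized()
        if is_dist:
            dist.barrier()  # every rank has flushed its shards past this point
        if not is_dist or dist.get_rank() == 0:
            with open(final_csv, "w") as f_out:
                f_out.write(header + "\n")
                for shard in sorted(tmp_dir.glob("metrics_rank_*_batch_*.csv")):
                    f_out.write(shard.read_text())
        if is_dist:
            dist.barrier()  # nobody proceeds (or cleans up) until the merge is done

With a barrier before the merge, the retry loop in on_test_epoch_end acts mostly as a guard against network filesystems whose writes are slow to become visible to rank 0.
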
From 200d232ad29a9b2b50ac1a24ae17699c4d04c039 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Wed, 11 Feb 2026 19:23:18 -0500
Subject: [PATCH 8/9] updated sample configuration files

---
 src/electrai/configs/MP/config.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml
index 3da0a1a0..3a296a32 100644
--- a/src/electrai/configs/MP/config.yaml
+++ b/src/electrai/configs/MP/config.yaml
@@ -11,6 +11,7 @@ data:
   val_frac: 0.005
   drop_last: false
   augmentation: false
+  random_seed: 42
 
   # downsample_label: 0
   # downsample_data: 0
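The random_seed key added here is consumed by split_data in the next patch: the split draws its permutation from a dedicated, seeded torch.Generator, so the train/validation partition is reproducible across runs and identical on every rank without disturbing the global RNG. The pattern in isolation:

    import torch

    def seeded_split(data_size: int, val_frac: float, random_seed: int = 42):
        g = torch.Generator()       # local generator; global RNG state untouched
        g.manual_seed(random_seed)
        indices = torch.randperm(data_size, generator=g)
        n_val = int(data_size * val_frac)
        return indices[n_val:].tolist(), indices[:n_val].tolist()

    a = seeded_split(1000, 0.005)
    b = seeded_split(1000, 0.005)
    assert a == b  # same seed, same split, on any process
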
From b784005a81e71b965cbb4d8817b53f1dd09ff189 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Wed, 11 Feb 2026 19:26:10 -0500
Subject: [PATCH 9/9] updated dataloader and test scripts

---
 src/electrai/dataloader/dataset.py | 38 ++++++++++++++++++++++--------
 src/electrai/dataloader/split.py   | 17 +++++++++----
 src/electrai/entrypoints/test.py   |  3 +--
 src/electrai/lightning.py          | 10 ++++++--
 4 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/src/electrai/dataloader/dataset.py b/src/electrai/dataloader/dataset.py
index 708b34f4..338cd214 100644
--- a/src/electrai/dataloader/dataset.py
+++ b/src/electrai/dataloader/dataset.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 import torch
+from lightning.pytorch import LightningDataModule
 from src.electrai.dataloader import utils
 from src.electrai.dataloader.collate import collate_fn
 from src.electrai.dataloader.split import split_data
@@ -15,7 +16,7 @@
 dtype_map = {"f32": torch.float32, "f16": torch.float16, "bf16": torch.bfloat16}
 
 
-class RhoRead:
+class RhoRead(LightningDataModule):
     def __init__(
         self,
         root: str | bytes | os.PathLike,
@@ -28,6 +29,7 @@ def __init__(
         drop_last: bool = False,
         split_file: str | bytes | os.PathLike | None = None,
         augmentation: bool = False,
+        random_seed: int = 42,
         **kwargs,
     ):
         super().__init__()
@@ -39,6 +41,9 @@ def __init__(
         self.val_frac = val_frac
         self.drop_last = drop_last
         self.split_file = split_file
+        self.precision = precision
+        self.augmentation = augmentation
+        self.random_seed = random_seed
 
         dataset = RhoData(self.root, precision=precision, augmentation=augmentation)
 
@@ -46,14 +51,29 @@ def __init__(
             dataset, val_frac=self.val_frac, split_file=self.split_file
         )
 
+    def setup(self, stage=None):
+        dataset = RhoData(
+            self.root, precision=self.precision, augmentation=self.augmentation
+        )
+        self.subsets = split_data(
+            dataset,
+            val_frac=self.val_frac,
+            split_file=self.split_file,
+            random_seed=self.random_seed,
+        )
+        if stage == "fit":
+            self.train_set = self.subsets["train"]
+            self.val_set = self.subsets["validation"]
+        elif stage == "test":
+            self.test_set = self.subsets["test"]
+
     def train_dataloader(self):
         return DataLoader(
             self.subsets["train"],
             self.batch_size,
             num_workers=self.train_workers,
             shuffle=True,
-            # sampler=DistributedSampler(self.train_set, drop_last=self.drop_last), do we need this eventhough we use pytorch lightning? important question
-            collate_fn=collate_fn,  # originally it was partial(collate_list_of_dicts, pin_memory=self.pin_memory) should we be concerned about partial and pin_memory?
+            collate_fn=collate_fn,
         )
 
     def val_dataloader(self):
@@ -61,9 +81,7 @@ def val_dataloader(self):
         return DataLoader(
             self.subsets["validation"],
             self.batch_size,
             num_workers=self.val_workers,
-            shuffle=False,  # I added this
-            # collate_fn=partial(collate_list_of_dicts, pin_memory=self.pin_memory),
-            # note: no sampler, so all devices will get full set
+            shuffle=False,
         )
 
@@ -71,12 +89,12 @@ def test_dataloader(self):
         return DataLoader(
             self.subsets["test"],
             batch_size=1,
             num_workers=self.val_workers,
-            collate_fn=collate_fn,  # partial(collate_list_of_dicts, pin_memory=self.pin_memory),
-            # note: distributed sampler will shuffle and distribute different parts of dataset
-            # to different nodes/devices
-            # sampler=DistributedEvalSampler(self.test_set),
+            collate_fn=collate_fn,
         )
 
+    def on_exception(self, exception: BaseException) -> None:
+        return
+
 
 class RhoData(Dataset):
     def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs):
diff --git a/src/electrai/dataloader/split.py b/src/electrai/dataloader/split.py
index 4aa18a16..6f8d876e 100644
--- a/src/electrai/dataloader/split.py
+++ b/src/electrai/dataloader/split.py
@@ -2,11 +2,16 @@
 
 import json
 
-import numpy as np
-from torch.utils.data import Subset
+import torch
+from torch.utils.data import Dataset, Subset
 
 
-def split_data(dataset, val_frac=0.005, split_file=None):
+def split_data(
+    dataset: Dataset,
+    val_frac: float = 0.005,
+    split_file: str | None = None,
+    random_seed: int = 42,
+):
     # Load or generate splits
     if split_file is not None:
         with open(split_file) as fp:
@@ -14,7 +19,11 @@ def split_data(
     else:
         data_size = len(dataset)
         validation_size = int(data_size * val_frac)
-        indices = np.random.permutation(data_size)
+        g = torch.Generator()
+        g.manual_seed(random_seed)
+
+        indices = torch.randperm(data_size, generator=g)
+
         splits = {
             "train": indices[validation_size:].tolist(),
             "validation": indices[:validation_size].tolist(),
diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py
index 1fe69a92..acc7eb2c 100644
--- a/src/electrai/entrypoints/test.py
+++ b/src/electrai/entrypoints/test.py
@@ -23,7 +23,6 @@ def test(args):
     # Data
     # -----------------------------
     datamodule = instantiate(cfg.data)
-    test_loader = datamodule.test_dataloader()
 
     # -----------------------------
     # Model (LightningModule handles architecture + loss + optimizer)
@@ -67,4 +66,4 @@ def test(args):
     if not ckpt.exists():
         raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
 
-    trainer.test(model=lit_model, dataloaders=test_loader, ckpt_path=ckpt)
+    trainer.test(model=lit_model, datamodule=datamodule, ckpt_path=ckpt)
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index 3a384cae..956430f2 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -93,13 +93,19 @@ def on_test_start(self):
         self.test_outputs = []
 
     def test_step(self, batch, batch_idx):
-        start_time = time.time()
         x = batch["data"]
         y = batch["label"]
         indices = batch["index"]
 
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
         preds = self(x)
         loss = self.loss_fn(preds, y)
+        end.record()
+
+        torch.cuda.synchronize()
+        elapsed = start.elapsed_time(end)
 
         self.log("test_loss", loss, prog_bar=True, sync_dist=True)
 
@@ -107,7 +113,7 @@ def test_step(self, batch, batch_idx):
             "target": y.detach().cpu(),
             "index": indices,
             "nmae": loss.detach().cpu(),
-            "time": time.time() - start_time,
+            "duration": elapsed,
         }
         if self.save_pred:
             out["pred"] = preds.detach().cpu()
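This final hunk swaps wall-clock timing for CUDA events, which measure the GPU work itself: kernel launches return immediately on the host, so time.time() around an un-synchronized forward pass mostly measures launch overhead. Note that elapsed_time returns milliseconds (hence the rename from "time" to "duration"), and that torch.cuda.Event presumes a CUDA device, so the CPU fallback selected earlier in test.py would raise here. The same pattern as a standalone helper (time_on_gpu_ms is an illustrative name, not project code):

    import torch

    def time_on_gpu_ms(fn, *args):
        """Return (result, elapsed_ms) for a GPU operation, timed with CUDA events."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn(*args)
        end.record()
        torch.cuda.synchronize()  # events complete asynchronously; sync before reading
        return result, start.elapsed_time(end)

    if torch.cuda.is_available():
        x = torch.randn(2048, 2048, device="cuda")
        _, ms = time_on_gpu_ms(torch.matmul, x, x)
        print(f"matmul: {ms:.2f} ms")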