2 changes: 2 additions & 0 deletions .gitignore
@@ -1,6 +1,8 @@
# Docs
.cache/
docs/reference/*
examples/MP/experiments
examples/QM9
*doctrees*
/site

3 changes: 3 additions & 0 deletions src/electrai/configs/MP/config.yaml
@@ -40,3 +40,6 @@ wb_pname: mp-experiment

# checkpoints
ckpt_path: ./checkpoints
save_pred: true
log_dir: ./logs
out_dir: ./predictions
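
Note on the three new keys: the test entrypoint below reads log_dir and out_dir through getattr fallbacks (logs/ and predictions/), while save_pred is accessed directly and so must be present when running the test subcommand. A minimal sketch of how the keys resolve (values illustrative):

    # Sketch: how test.py resolves the new keys; cfg comes from yaml.safe_load
    from pathlib import Path
    from types import SimpleNamespace

    cfg = SimpleNamespace(save_pred=True, log_dir="./logs", out_dir="./predictions")
    out_dir = Path(getattr(cfg, "out_dir", "predictions")) if cfg.save_pred else None
    log_dir = Path(getattr(cfg, "log_dir", "logs"))  # per-rank tmp CSVs go under log_dir / "tmp"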
26 changes: 19 additions & 7 deletions src/electrai/dataloader/dataset.py
@@ -62,15 +62,29 @@ def setup(self, stage=None):
elif stage == "test":
self.test_set = self.subsets["test"]

def setup(self, stage=None):
dataset = RhoData(
self.root, precision=self.precision, augmentation=self.augmentation
)
self.subsets = split_data(
dataset,
val_frac=self.val_frac,
split_file=self.split_file,
random_seed=self.random_seed,
)
if stage == "fit":
self.train_set = self.subsets["train"]
self.val_set = self.subsets["validation"]
elif stage == "test":
self.test_set = self.subsets["test"]

def train_dataloader(self):
return DataLoader(
self.train_set,
self.batch_size,
num_workers=self.train_workers,
shuffle=True,
collate_fn=collate_fn,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)

def val_dataloader(self):
@@ -79,9 +93,6 @@ def val_dataloader(self):
            self.batch_size,
            num_workers=self.val_workers,
            shuffle=False,
            collate_fn=collate_fn,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )

    def test_dataloader(self):
@@ -90,10 +101,11 @@ def test_dataloader(self):
            batch_size=1,
            num_workers=self.val_workers,
            collate_fn=collate_fn,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
        )

    def on_exception(self, exception: BaseException) -> None:
        return


class RhoData(Dataset):
    def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs):
9 changes: 9 additions & 0 deletions src/electrai/entrypoints/main.py
@@ -2,8 +2,12 @@

import argparse

import torch
from src.electrai.entrypoints.test import test
from src.electrai.entrypoints.train import train

torch.backends.cudnn.conv.fp32_precision = "tf32"


def main() -> None:
"""Entry point.
@@ -24,10 +28,15 @@ def main() -> None:
    train_parser = subparsers.add_parser("train", help="Train the model")
    train_parser.add_argument("--config", type=str, required=True)

    test_parser = subparsers.add_parser("test", help="Evaluate the model")
    test_parser.add_argument("--config", type=str, required=True)

    args = parser.parse_args()

    if args.command == "train":
        train(args)
    elif args.command == "test":
        test(args)
    else:
        raise ValueError(f"Unknown command: {args.command}")

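Assuming the module layout in this diff, the new subcommand can be exercised in-process (the argv values are illustrative):

    # Sketch: dispatching the new "test" subcommand via main()
    import sys

    from src.electrai.entrypoints.main import main

    sys.argv = ["electrai", "test", "--config", "src/electrai/configs/MP/config.yaml"]
    main()  # parse_args() reads sys.argv[1:] and routes to test(args)
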
69 changes: 69 additions & 0 deletions src/electrai/entrypoints/test.py
@@ -0,0 +1,69 @@
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace

import torch
import yaml
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from src.electrai.lightning import LightningGenerator


def test(args):
    # -----------------------------
    # Load YAML config
    # -----------------------------
    config_path = Path(args.config)
    with config_path.open() as f:
        cfg_dict = yaml.safe_load(f)
    cfg = SimpleNamespace(**cfg_dict)

    # -----------------------------
    # Data
    # -----------------------------
    datamodule = instantiate(cfg.data)

    # -----------------------------
    # Model (LightningModule handles architecture + loss + optimizer)
    # -----------------------------
    lit_model = LightningGenerator(cfg)

    # -----------------------------
    # Checkpoint
    # -----------------------------
    ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))

    # -----------------------------
    # Output directories
    # -----------------------------
    if cfg.save_pred:
        out_dir = Path(getattr(cfg, "out_dir", "predictions"))
        out_dir.mkdir(exist_ok=True, parents=True)
    else:
        out_dir = None
    log_dir = Path(getattr(cfg, "log_dir", "logs"))
    tmp_dir = log_dir / "tmp"
    for directory in [log_dir, tmp_dir]:
        directory.mkdir(exist_ok=True, parents=True)

    # -----------------------------
    # Trainer
    # -----------------------------
    trainer = Trainer(
        logger=None,
        callbacks=None,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=cfg.model_precision,
    )

    lit_model.test_cfg = SimpleNamespace(
        log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred
    )

    # -----------------------------
    # Test
    # -----------------------------
    ckpt = ckpt_path / "last.ckpt"
    if not ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    trainer.test(model=lit_model, datamodule=datamodule, ckpt_path=ckpt)
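
For reference, test() touches these config keys: data (instantiated by Hydra into the datamodule), model_precision, ckpt_path (default ./checkpoints, which must contain last.ckpt), save_pred (required, no fallback), and log_dir/out_dir (fallbacks logs/ and predictions/). A minimal sketch of the test-relevant section of the config, values illustrative:

    # assumed minimal keys for the test subcommand; names per this diff
    model_precision: 32
    ckpt_path: ./checkpoints
    save_pred: true
    log_dir: ./logs
    out_dir: ./predictions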
2 changes: 0 additions & 2 deletions src/electrai/entrypoints/train.py
@@ -11,8 +11,6 @@
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from src.electrai.lightning import LightningGenerator

torch.backends.cudnn.conv.fp32_precision = "tf32"


def train(args):
    # -----------------------------
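Moving the TF32 toggle into main.py sets it once before either subcommand runs. torch.backends.cudnn.conv.fp32_precision is the newer per-backend precision control; on PyTorch builds that predate it, the equivalent global switches would be (a sketch, not part of this PR):

    # Legacy TF32 flags for older PyTorch versions (not part of this diff)
    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmuls
    torch.backends.cudnn.allow_tf32 = True  # TF32 for cuDNN convolutions
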
105 changes: 105 additions & 0 deletions src/electrai/lightning.py
@@ -1,6 +1,11 @@
from __future__ import annotations

import shutil
import time

import numpy as np
import torch
import torch.distributed as dist
from lightning.pytorch import LightningModule
from src.electrai.model.loss.charge import NormMAE
from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet
@@ -79,3 +84,103 @@ def configure_optimizers(self):
            optimizer, [linsch, cossch], milestones=[self.cfg.warmup_length]
        )
        return [optimizer], [scheduler]

    def on_test_start(self):
        self.log_dir = self.test_cfg.log_dir
        self.out_dir = self.test_cfg.out_dir
        self.tmp_dir = self.test_cfg.tmp_dir
        self.save_pred = self.test_cfg.save_pred
        self.test_outputs = []

    def test_step(self, batch, batch_idx):
        x = batch["data"]
        y = batch["label"]
        indices = batch["index"]

        # CUDA events only exist on GPU; fall back to wall-clock timing on CPU
        # so the CPU path offered by the test entrypoint does not crash.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        else:
            t0 = time.perf_counter()

        preds = self(x)
        loss = self.loss_fn(preds, y)

        if use_cuda:
            end.record()
            torch.cuda.synchronize()
            elapsed = start.elapsed_time(end)  # milliseconds
        else:
            elapsed = (time.perf_counter() - t0) * 1e3

        self.log("test_loss", loss, prog_bar=True, sync_dist=True)

        out = {
            "target": y.detach().cpu(),
            "index": indices,
            "nmae": loss.detach().cpu(),
            "duration": elapsed,
        }
        if self.save_pred:
            out["pred"] = preds.detach().cpu()
        return out

def on_test_batch_end(self, outputs, batch, batch_idx):
indices = outputs["index"]
nmae = outputs["nmae"]

if self.save_pred:
preds = outputs["pred"]
for i in range(len(indices)):
idx = indices[i]
np.save(
self.out_dir / f"rank_{self.global_rank}_{idx}.npy",
preds[i].squeeze(0).cpu().numpy(),
)

if isinstance(nmae, torch.Tensor) and nmae.ndim == 0:
nmae = nmae.unsqueeze(0)
tmp_csv = (
self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv"
)
with open(tmp_csv, "w") as f:
for idx, n in zip(indices, nmae, strict=True):
f.write(f"rank_{self.global_rank},{idx},{n.item()}\n")

    def on_test_epoch_end(self):
        is_dist = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if is_dist else 0

        # Count only files written by THIS rank
        local_count = len(list(self.tmp_dir.glob(f"metrics_rank_{rank}_batch_*.csv")))

        if is_dist:
            count_tensor = torch.tensor(
                [local_count], dtype=torch.long, device=self.device
            )
            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
            expected_total = int(count_tensor.item())
            dist.barrier()
        else:
            expected_total = local_count

        final_csv = self.log_dir / "metrics.csv"

        if self.global_rank == 0:
            retries = 0
            all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
            while len(all_tmp_csvs) < expected_total and retries < 60:
                time.sleep(1)
                all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))
                retries += 1

            if len(all_tmp_csvs) < expected_total:
                raise RuntimeError(
                    f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}."
                )

            with open(final_csv, "w") as f_out:
                f_out.write("rank,index,nmae\n")
                for tmp_csv in all_tmp_csvs:
                    with open(tmp_csv) as f_in:
                        for line in f_in:
                            f_out.write(line)

            shutil.rmtree(self.tmp_dir, ignore_errors=True)

        if is_dist:
            dist.barrier()
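
The per-rank temporary CSVs sidestep concurrent writes to a single file under DDP: each rank writes metrics_rank_{rank}_batch_{idx}.csv, the ranks agree on the expected file count via all_reduce, and rank 0 merges everything into log_dir/metrics.csv. A quick way to summarize the merged file afterwards (assuming pandas, which this PR does not otherwise use):

    # Sketch: aggregate per-sample NMAE from the merged metrics file
    import pandas as pd

    df = pd.read_csv("logs/metrics.csv")  # columns: rank, index, nmae
    print(df["nmae"].describe())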