diff --git a/.gitignore b/.gitignore index c2ade997..b7920944 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Docs .cache/ docs/reference/* +./examples/MP/experiments +./examples/QM9 *doctrees* /site diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml index b0a93ce7..3da0a1a0 100644 --- a/src/electrai/configs/MP/config.yaml +++ b/src/electrai/configs/MP/config.yaml @@ -39,3 +39,6 @@ wb_pname: mp-experiment # checkpoints ckpt_path: ./checkpoints +save_pred: true +log_dir: ./logs +out_dir: ./predictions diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py index 415600d4..1fe69a92 100644 --- a/src/electrai/entrypoints/test.py +++ b/src/electrai/entrypoints/test.py @@ -39,6 +39,15 @@ def test(args): # ----------------------------- # Trainer # ----------------------------- + if cfg.save_pred: + out_dir = Path(getattr(cfg, "out_dir", "predictions")) + out_dir.mkdir(exist_ok=True, parents=True) + else: + out_dir = None + log_dir = Path(getattr(cfg, "log_dir", "logs")) + tmp_dir = log_dir / "tmp" + for directory in [log_dir, tmp_dir]: + directory.mkdir(exist_ok=True, parents=True) trainer = Trainer( logger=None, callbacks=None, @@ -47,6 +56,10 @@ def test(args): precision=cfg.model_precision, ) + lit_model.test_cfg = SimpleNamespace( + log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred + ) + # ----------------------------- # Train # ----------------------------- diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 1e674074..3a384cae 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -1,10 +1,11 @@ from __future__ import annotations +import shutil import time -from pathlib import Path import numpy as np import torch +import torch.distributed as dist from lightning.pytorch import LightningModule from src.electrai.model.loss.charge import NormMAE from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet @@ -87,6 +88,8 @@ def configure_optimizers(self): def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir + self.tmp_dir = self.test_cfg.tmp_dir + self.save_pred = self.test_cfg.save_pred self.test_outputs = [] def test_step(self, batch, batch_idx): @@ -100,50 +103,78 @@ def test_step(self, batch, batch_idx): self.log("test_loss", loss, prog_bar=True, sync_dist=True) - return { - "pred": preds.detach().cpu(), + out = { "target": y.detach().cpu(), "index": indices, "nmae": loss.detach().cpu(), - "time": time.time() - start_time, # + batch["load_time"][0], ??? + "time": time.time() - start_time, } + if self.save_pred: + out["pred"] = preds.detach().cpu() + return out def on_test_batch_end(self, outputs, batch, batch_idx): - if self.out_dir is not None: - out_dir = Path(self.out_dir) - out_dir.mkdir(exist_ok=True, parents=True) + indices = outputs["index"] + nmae = outputs["nmae"] + if self.save_pred: preds = outputs["pred"] - indices = outputs["index"] - for i in range(len(indices)): idx = indices[i] - pred_i = preds[i].numpy() - np.save(out_dir / f"{idx}.npy", pred_i) - - self.test_outputs.append(outputs) + np.save( + self.out_dir / f"rank_{self.global_rank}_{idx}.npy", + preds[i].squeeze(0).cpu().numpy(), + ) + + if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: + nmae = nmae.unsqueeze(0) + tmp_csv = ( + self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv" + ) + with open(tmp_csv, "w") as f: + for idx, n in zip(indices, nmae, strict=True): + f.write(f"rank_{self.global_rank},{idx},{n.item()}\n") def on_test_epoch_end(self): - indices = [] - nmae_chunks = [] - - for output in self.test_outputs: - indices.extend(list(output["index"])) - - batch_nmae = output["nmae"] - if batch_nmae.ndim == 0: - nmae_chunks.append(batch_nmae.unsqueeze(0)) - else: - nmae_chunks.append(batch_nmae) - - all_nmae = torch.cat(nmae_chunks, dim=0) - - if self.log_dir is not None: - log_dir = Path(self.log_dir) - log_dir.mkdir(exist_ok=True, parents=True) - csv_path = log_dir / "metrics.csv" - - with open(csv_path, "w") as f: - f.write("index,nmae\n") - for ind, err in zip(indices, all_nmae.tolist(), strict=False): - f.write(f"{ind},{err}\n") + is_dist = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_dist else 0 + + # Count only files written by THIS rank + local_count = len(list(self.tmp_dir.glob(f"metrics_rank_{rank}_batch_*.csv"))) + + if is_dist: + count_tensor = torch.tensor( + [local_count], dtype=torch.long, device=self.device + ) + dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) + expected_total = int(count_tensor.item()) + dist.barrier() + else: + expected_total = local_count + + final_csv = self.log_dir / "metrics.csv" + + if self.global_rank == 0: + retries = 0 + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) + while len(all_tmp_csvs) < expected_total and retries < 60: + time.sleep(1) + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) + retries += 1 + + if len(all_tmp_csvs) < expected_total: + raise RuntimeError( + f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}." + ) + + with open(final_csv, "w") as f_out: + f_out.write("rank,index,nmae\n") + for tmp_csv in all_tmp_csvs: + with open(tmp_csv) as f_in: + for line in f_in: + f_out.write(line) + + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + if is_dist: + dist.barrier()