From 581030858528e90329bd2a673d3e0cce5b96c8ea Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Sun, 11 Jan 2026 00:31:51 -0500 Subject: [PATCH 01/14] saving output from different ranks --- src/electrai/lightning.py | 58 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 9ba7033d..49b2e277 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -87,6 +87,14 @@ def configure_optimizers(self): def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir + if self.out_dir is not None: + self.out_dir = Path(self.out_dir) + self.out_dir.mkdir(exist_ok=True, parents=True) + if self.log_dir is not None: + self.log_dir = Path(self.log_dir) + self.log_dir.mkdir(exist_ok=True, parents=True) + self.tmp_dir = Path(self.out_dir) / "tmp" + self.tmp_dir.mkdir(exist_ok=True, parents=True) self.test_outputs = [] def test_step(self, batch, batch_idx): @@ -110,40 +118,38 @@ def test_step(self, batch, batch_idx): def on_test_batch_end(self, outputs, batch, batch_idx): if self.out_dir is not None: - out_dir = Path(self.out_dir) - out_dir.mkdir(exist_ok=True, parents=True) - preds = outputs["pred"] indices = outputs["index"] + nmae = outputs["nmae"] + # Save prediction files for i in range(len(indices)): idx = indices[i] - pred_i = preds[i].numpy() - np.save(out_dir / f"{idx}.npy", pred_i) + np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) - self.test_outputs.append(outputs) + if self.log_dir is not None: + # Save batch-level CSV + if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: + nmae = nmae.unsqueeze(0) + tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" + with open(tmp_csv, "w") as f: + for i, n in zip(indices, nmae, strict=False): + idx = i + f.write(f"{idx},{n.item()}\n") def on_test_epoch_end(self): - index = [] - nmae_all = [] - - for o in self.test_outputs: - index.extend(list(o["index"])) + if self.log_dir is None: + return - n = o["nmae"] - if n.ndim == 0: - nmae_all.append(n.unsqueeze(0)) - else: - nmae_all.append(n) + final_csv = self.log_dir / "metrics.csv" - nmae = torch.cat(nmae_all, dim=0) + # gather all batch CSVs + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) - if self.log_dir is not None: - log_dir = Path(self.log_dir) - log_dir.mkdir(exist_ok=True, parents=True) - csv_path = Path(self.log_dir) / "metrics.csv" - - with open(csv_path, "w") as f: - f.write("index,nmae\n") - for ind, err in zip(index, nmae.tolist(), strict=False): - f.write(f"{ind},{err}\n") + # write final CSV with header + with open(final_csv, "w") as f_out: + f_out.write("index,nmae\n") + for tmp_csv in all_tmp_csvs: + with open(tmp_csv) as f_in: + for line in f_in: + f_out.write(line) From 0e2455aadd75e32fe30f9d73462da750bf683eb7 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 14:19:05 -0500 Subject: [PATCH 02/14] Clean up log and output directory handling --- src/electrai/configs/MP/config.yaml | 2 ++ src/electrai/entrypoints/test.py | 6 +++++- src/electrai/lightning.py | 10 ++-------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml index b0a93ce7..45e8e986 100644 --- a/src/electrai/configs/MP/config.yaml +++ b/src/electrai/configs/MP/config.yaml @@ -39,3 +39,5 @@ wb_pname: mp-experiment # checkpoints ckpt_path: ./checkpoints +log_dir: ./logs +out_dir: ./predictions diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py index 145497a3..d287ebd5 100644 --- a/src/electrai/entrypoints/test.py +++ b/src/electrai/entrypoints/test.py @@ -47,7 +47,11 @@ def test(args): devices=1, precision=cfg.model_precision, ) - lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir) + out_dir = Path(getattr(cfg, "out_dir", "predictions")) + out_dir.mkdir(exist_ok=True, parents=True) + log_dir = Path(getattr(cfg, "log_dir", "logs")) + log_dir.mkdir(exist_ok=True, parents=True) + lit_model.test_cfg = SimpleNamespace(log_dir=log_dir, out_dir=out_dir) # ----------------------------- # Train diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 49b2e277..834a8db6 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -87,14 +87,8 @@ def configure_optimizers(self): def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir - if self.out_dir is not None: - self.out_dir = Path(self.out_dir) - self.out_dir.mkdir(exist_ok=True, parents=True) - if self.log_dir is not None: - self.log_dir = Path(self.log_dir) - self.log_dir.mkdir(exist_ok=True, parents=True) - self.tmp_dir = Path(self.out_dir) / "tmp" - self.tmp_dir.mkdir(exist_ok=True, parents=True) + self.tmp_dir = Path(self.out_dir) / "tmp" + self.tmp_dir.mkdir(exist_ok=True, parents=True) self.test_outputs = [] def test_step(self, batch, batch_idx): From c3f93e0d4a258c1b6d23fe03f3f7138e5a03fc0c Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 14:22:58 -0500 Subject: [PATCH 03/14] correction --- src/electrai/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 834a8db6..277a9315 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from pathlib import Path import numpy as np import torch @@ -87,7 +86,7 @@ def configure_optimizers(self): def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir - self.tmp_dir = Path(self.out_dir) / "tmp" + self.tmp_dir = self.out_dir / "tmp" self.tmp_dir.mkdir(exist_ok=True, parents=True) self.test_outputs = [] From 3ca47e9d0778996c4b4fbd45a602b01d0fff6810 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 14:24:42 -0500 Subject: [PATCH 04/14] set strict=True --- src/electrai/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 277a9315..5fe15825 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -126,7 +126,7 @@ def on_test_batch_end(self, outputs, batch, batch_idx): nmae = nmae.unsqueeze(0) tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" with open(tmp_csv, "w") as f: - for i, n in zip(indices, nmae, strict=False): + for i, n in zip(indices, nmae, strict=True): idx = i f.write(f"{idx},{n.item()}\n") From ac7952edfe16ffd76dde9a1b1fadcd32d8386ffc Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 14:50:17 -0500 Subject: [PATCH 05/14] handling file saving and deleting on rank 0 --- src/electrai/lightning.py | 65 +++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 5fe15825..4ba99d86 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.distributed as dist from lightning.pytorch import LightningModule from src.electrai.model.loss.charge import NormMAE from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet @@ -110,39 +111,43 @@ def test_step(self, batch, batch_idx): } def on_test_batch_end(self, outputs, batch, batch_idx): - if self.out_dir is not None: - preds = outputs["pred"] - indices = outputs["index"] - nmae = outputs["nmae"] - - # Save prediction files - for i in range(len(indices)): - idx = indices[i] - np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) - - if self.log_dir is not None: - # Save batch-level CSV - if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: - nmae = nmae.unsqueeze(0) - tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" - with open(tmp_csv, "w") as f: - for i, n in zip(indices, nmae, strict=True): - idx = i - f.write(f"{idx},{n.item()}\n") + preds = outputs["pred"] + indices = outputs["index"] + nmae = outputs["nmae"] + + # Save prediction files + for i in range(len(indices)): + idx = indices[i] + np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) + + # Save batch-level CSV + if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: + nmae = nmae.unsqueeze(0) + tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" + with open(tmp_csv, "w") as f: + for i, n in zip(indices, nmae, strict=True): + idx = i + f.write(f"{idx},{n.item()}\n") def on_test_epoch_end(self): - if self.log_dir is None: - return + is_dist = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_dist else 0 - final_csv = self.log_dir / "metrics.csv" + if is_dist: + dist.barrier() - # gather all batch CSVs + final_csv = self.log_dir / "metrics.csv" all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) - # write final CSV with header - with open(final_csv, "w") as f_out: - f_out.write("index,nmae\n") - for tmp_csv in all_tmp_csvs: - with open(tmp_csv) as f_in: - for line in f_in: - f_out.write(line) + if rank == 0: + with open(final_csv, "w") as f_out: + f_out.write("index,nmae\n") + for tmp_csv in all_tmp_csvs: + with open(tmp_csv) as f_in: + for line in f_in: + f_out.write(line) + + self.tmp_dir.rmdir() + + if is_dist: + dist.barrier() From 5c8244a6d95fac3ac161bcd2d6eb48ededc51adf Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 17:42:29 -0500 Subject: [PATCH 06/14] file handling on rank 0 --- src/electrai/entrypoints/test.py | 15 ++++++++++----- src/electrai/lightning.py | 13 +++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py index d287ebd5..678a9d1f 100644 --- a/src/electrai/entrypoints/test.py +++ b/src/electrai/entrypoints/test.py @@ -40,6 +40,12 @@ def test(args): # ----------------------------- # Trainer # ----------------------------- + out_dir = Path(getattr(cfg, "out_dir", "predictions")) + log_dir = Path(getattr(cfg, "log_dir", "logs")) + tmp_dir = log_dir / "tmp" + for directory in [log_dir, out_dir, tmp_dir]: + directory.mkdir(exist_ok=True, parents=True) + trainer = Trainer( logger=None, callbacks=None, @@ -47,11 +53,10 @@ def test(args): devices=1, precision=cfg.model_precision, ) - out_dir = Path(getattr(cfg, "out_dir", "predictions")) - out_dir.mkdir(exist_ok=True, parents=True) - log_dir = Path(getattr(cfg, "log_dir", "logs")) - log_dir.mkdir(exist_ok=True, parents=True) - lit_model.test_cfg = SimpleNamespace(log_dir=log_dir, out_dir=out_dir) + + lit_model.test_cfg = SimpleNamespace( + log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir + ) # ----------------------------- # Train diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 4ba99d86..22d40479 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -1,5 +1,6 @@ from __future__ import annotations +import shutil import time import numpy as np @@ -87,8 +88,7 @@ def configure_optimizers(self): def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir - self.tmp_dir = self.out_dir / "tmp" - self.tmp_dir.mkdir(exist_ok=True, parents=True) + self.tmp_dir = self.test_cfg.tmp_dir self.test_outputs = [] def test_step(self, batch, batch_idx): @@ -107,7 +107,7 @@ def test_step(self, batch, batch_idx): "target": y.detach().cpu(), "index": indices, "nmae": loss.detach().cpu(), - "time": time.time() - start_time, # + batch["load_time"][0], ??? + "time": time.time() - start_time, } def on_test_batch_end(self, outputs, batch, batch_idx): @@ -115,12 +115,10 @@ def on_test_batch_end(self, outputs, batch, batch_idx): indices = outputs["index"] nmae = outputs["nmae"] - # Save prediction files for i in range(len(indices)): idx = indices[i] np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) - # Save batch-level CSV if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: nmae = nmae.unsqueeze(0) tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" @@ -131,7 +129,6 @@ def on_test_batch_end(self, outputs, batch, batch_idx): def on_test_epoch_end(self): is_dist = dist.is_available() and dist.is_initialized() - rank = dist.get_rank() if is_dist else 0 if is_dist: dist.barrier() @@ -139,7 +136,7 @@ def on_test_epoch_end(self): final_csv = self.log_dir / "metrics.csv" all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) - if rank == 0: + if self.global_rank == 0: with open(final_csv, "w") as f_out: f_out.write("index,nmae\n") for tmp_csv in all_tmp_csvs: @@ -147,7 +144,7 @@ def on_test_epoch_end(self): for line in f_in: f_out.write(line) - self.tmp_dir.rmdir() + shutil.rmtree(self.tmp_dir, ignore_errors=True) if is_dist: dist.barrier() From 9cb1957a123e7c47e07551dc97c288f5a2248de2 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 18:04:23 -0500 Subject: [PATCH 07/14] set up prediction saving using a configurable parameter: save_pred --- src/electrai/entrypoints/test.py | 11 +++++++---- src/electrai/lightning.py | 8 +++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/electrai/entrypoints/test.py b/src/electrai/entrypoints/test.py index 678a9d1f..f9c2ab03 100644 --- a/src/electrai/entrypoints/test.py +++ b/src/electrai/entrypoints/test.py @@ -40,12 +40,15 @@ def test(args): # ----------------------------- # Trainer # ----------------------------- - out_dir = Path(getattr(cfg, "out_dir", "predictions")) + if cfg.save_pred: + out_dir = Path(getattr(cfg, "out_dir", "predictions")) + out_dir.mkdir(exist_ok=True, parents=True) + else: + out_dir = None log_dir = Path(getattr(cfg, "log_dir", "logs")) tmp_dir = log_dir / "tmp" - for directory in [log_dir, out_dir, tmp_dir]: + for directory in [log_dir, tmp_dir]: directory.mkdir(exist_ok=True, parents=True) - trainer = Trainer( logger=None, callbacks=None, @@ -55,7 +58,7 @@ def test(args): ) lit_model.test_cfg = SimpleNamespace( - log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir + log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred ) # ----------------------------- diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 22d40479..e78eb0ea 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -89,6 +89,7 @@ def on_test_start(self): self.log_dir = self.test_cfg.log_dir self.out_dir = self.test_cfg.out_dir self.tmp_dir = self.test_cfg.tmp_dir + self.save_pred = self.test_cfg.save_pred self.test_outputs = [] def test_step(self, batch, batch_idx): @@ -115,9 +116,10 @@ def on_test_batch_end(self, outputs, batch, batch_idx): indices = outputs["index"] nmae = outputs["nmae"] - for i in range(len(indices)): - idx = indices[i] - np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) + if self.save_pred: + for i in range(len(indices)): + idx = indices[i] + np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: nmae = nmae.unsqueeze(0) From bbaed9f013ba3b6b86ac93a8b1db4c13b3c80054 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Wed, 4 Feb 2026 18:07:19 -0500 Subject: [PATCH 08/14] set up prediction saving using a configurable parameter: save_pred --- src/electrai/configs/MP/config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml index 45e8e986..3da0a1a0 100644 --- a/src/electrai/configs/MP/config.yaml +++ b/src/electrai/configs/MP/config.yaml @@ -39,5 +39,6 @@ wb_pname: mp-experiment # checkpoints ckpt_path: ./checkpoints +save_pred: true log_dir: ./logs out_dir: ./predictions From 530d047092a69e40e4fe615578df8caed7a244bb Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Fri, 6 Feb 2026 16:08:16 -0500 Subject: [PATCH 09/14] return predictions conditioned on configurable save_pred parameter --- src/electrai/lightning.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index e78eb0ea..49575b4a 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -103,20 +103,22 @@ def test_step(self, batch, batch_idx): self.log("test_loss", loss, prog_bar=True, sync_dist=True) - return { - "pred": preds.detach().cpu(), + out = { "target": y.detach().cpu(), "index": indices, "nmae": loss.detach().cpu(), "time": time.time() - start_time, } + if self.save_pred: + out["pred"] = preds.detach().cpu() + return out def on_test_batch_end(self, outputs, batch, batch_idx): - preds = outputs["pred"] indices = outputs["index"] nmae = outputs["nmae"] if self.save_pred: + preds = outputs["pred"] for i in range(len(indices)): idx = indices[i] np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) From a3baac30ace83f02acd382c89a7323968f208180 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Fri, 6 Feb 2026 16:19:45 -0500 Subject: [PATCH 10/14] removed redundant variable --- src/electrai/lightning.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 49575b4a..1995e238 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -121,14 +121,16 @@ def on_test_batch_end(self, outputs, batch, batch_idx): preds = outputs["pred"] for i in range(len(indices)): idx = indices[i] - np.save(self.out_dir / f"{idx}.npy", preds[i].squeeze(0).cpu().numpy()) + np.save( + self.out_dir / f"rank_{self.global_rank}_{idx}.npy", + preds[i].squeeze(0).cpu().numpy(), + ) if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: nmae = nmae.unsqueeze(0) tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" with open(tmp_csv, "w") as f: - for i, n in zip(indices, nmae, strict=True): - idx = i + for idx, n in zip(indices, nmae, strict=True): f.write(f"{idx},{n.item()}\n") def on_test_epoch_end(self): From f05586ab841ada0ae3da655e78d40d61d7ed00aa Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Fri, 6 Feb 2026 17:10:59 -0500 Subject: [PATCH 11/14] Confirm all temp files are visible --- src/electrai/lightning.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 1995e238..43c7a066 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist from lightning.pytorch import LightningModule + from src.electrai.model.loss.charge import NormMAE from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet @@ -136,13 +137,38 @@ def on_test_batch_end(self, outputs, batch, batch_idx): def on_test_epoch_end(self): is_dist = dist.is_available() and dist.is_initialized() + # Each rank counts how many tmp CSVs it wrote + local_count = len( + list(self.tmp_dir.glob(f"metrics_batch_{self.global_rank}_*.csv")) + ) + if is_dist: + # Sum file counts across all ranks so rank 0 knows the expected total + count_tensor = torch.tensor( + [local_count], dtype=torch.long, device=self.device + ) + dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) + expected_total = count_tensor.item() dist.barrier() - - final_csv = self.log_dir / "metrics.csv" - all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) + else: + expected_total = local_count if self.global_rank == 0: + final_csv = self.log_dir / "metrics.csv" + + # Retry glob until all files are visible (handles NFS caching) + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) + retries = 0 + while len(all_tmp_csvs) < expected_total and retries < 30: + time.sleep(1) + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) + retries += 1 + + if len(all_tmp_csvs) < expected_total: + raise RuntimeError( + f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}. Possible NFS caching issue." + ) + with open(final_csv, "w") as f_out: f_out.write("index,nmae\n") for tmp_csv in all_tmp_csvs: From 74b24712b6e71460972f14f62a30efed0e763ccd Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Fri, 6 Feb 2026 18:46:49 -0500 Subject: [PATCH 12/14] log rank in the metrics file --- src/electrai/lightning.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 43c7a066..5f565644 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist from lightning.pytorch import LightningModule - from src.electrai.model.loss.charge import NormMAE from src.electrai.model.srgan_layernorm_pbc import GeneratorResNet @@ -129,18 +128,18 @@ def on_test_batch_end(self, outputs, batch, batch_idx): if isinstance(nmae, torch.Tensor) and nmae.ndim == 0: nmae = nmae.unsqueeze(0) - tmp_csv = self.tmp_dir / f"metrics_batch_{self.global_rank}_{batch_idx}.csv" + tmp_csv = ( + self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv" + ) with open(tmp_csv, "w") as f: for idx, n in zip(indices, nmae, strict=True): - f.write(f"{idx},{n.item()}\n") + f.write(f"rank_{self.global_rank},{idx},{n.item()}\n") def on_test_epoch_end(self): is_dist = dist.is_available() and dist.is_initialized() # Each rank counts how many tmp CSVs it wrote - local_count = len( - list(self.tmp_dir.glob(f"metrics_batch_{self.global_rank}_*.csv")) - ) + local_count = len(list(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))) if is_dist: # Sum file counts across all ranks so rank 0 knows the expected total @@ -153,15 +152,18 @@ def on_test_epoch_end(self): else: expected_total = local_count + final_csv = self.log_dir / "metrics.csv" + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) + if self.global_rank == 0: final_csv = self.log_dir / "metrics.csv" # Retry glob until all files are visible (handles NFS caching) - all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) retries = 0 while len(all_tmp_csvs) < expected_total and retries < 30: time.sleep(1) - all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_batch_*.csv")) + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) retries += 1 if len(all_tmp_csvs) < expected_total: @@ -170,7 +172,7 @@ def on_test_epoch_end(self): ) with open(final_csv, "w") as f_out: - f_out.write("index,nmae\n") + f_out.write("rank,index,nmae\n") for tmp_csv in all_tmp_csvs: with open(tmp_csv) as f_in: for line in f_in: From 103ac21687c7df217cea40edb7d30afb4d821082 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Fri, 6 Feb 2026 22:48:47 -0500 Subject: [PATCH 13/14] updated gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index c2ade997..b7920944 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Docs .cache/ docs/reference/* +./examples/MP/experiments +./examples/QM9 *doctrees* /site From 510fddde0ff5e6e24046bc3c14868ca5bc9d8916 Mon Sep 17 00:00:00 2001 From: Hananeh Oliaei Date: Fri, 6 Feb 2026 22:55:16 -0500 Subject: [PATCH 14/14] updated lightning module --- src/electrai/lightning.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py index 5f565644..3a384cae 100644 --- a/src/electrai/lightning.py +++ b/src/electrai/lightning.py @@ -137,38 +137,34 @@ def on_test_batch_end(self, outputs, batch, batch_idx): def on_test_epoch_end(self): is_dist = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_dist else 0 - # Each rank counts how many tmp CSVs it wrote - local_count = len(list(self.tmp_dir.glob("metrics_rank_*_batch_*.csv"))) + # Count only files written by THIS rank + local_count = len(list(self.tmp_dir.glob(f"metrics_rank_{rank}_batch_*.csv"))) if is_dist: - # Sum file counts across all ranks so rank 0 knows the expected total count_tensor = torch.tensor( [local_count], dtype=torch.long, device=self.device ) dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) - expected_total = count_tensor.item() + expected_total = int(count_tensor.item()) dist.barrier() else: expected_total = local_count final_csv = self.log_dir / "metrics.csv" - all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) if self.global_rank == 0: - final_csv = self.log_dir / "metrics.csv" - - # Retry glob until all files are visible (handles NFS caching) - all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) retries = 0 - while len(all_tmp_csvs) < expected_total and retries < 30: + all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) + while len(all_tmp_csvs) < expected_total and retries < 60: time.sleep(1) all_tmp_csvs = sorted(self.tmp_dir.glob("metrics_rank_*_batch_*.csv")) retries += 1 if len(all_tmp_csvs) < expected_total: raise RuntimeError( - f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}. Possible NFS caching issue." + f"Expected {expected_total} CSV files but found {len(all_tmp_csvs)}." ) with open(final_csv, "w") as f_out: