diff --git a/alphanet/eval.py b/alphanet/eval.py index 7c3ef3a..d31c548 100644 --- a/alphanet/eval.py +++ b/alphanet/eval.py @@ -142,6 +142,32 @@ def run_analysis(evaluator: Evaluator, loaders: dict, output_dir: str, quiet: bo # Ensure output directory exists Path(output_dir).mkdir(parents=True, exist_ok=True) + + split_sizes = { + "Train": len(loaders["train"].dataset), + "Validation": len(loaders["valid"].dataset), + "Test": len(loaders["test"].dataset), + } + + split_table = Table(title="Evaluation Splits", show_header=False, box=None) + split_table.add_column("Split", style="dim") + split_table.add_column("Samples") + for split_name, split_size in split_sizes.items(): + split_table.add_row(split_name, str(split_size)) + console.print(Panel.fit(split_table, border_style="dim")) + + if all(size == 0 for size in split_sizes.values()): + raise RuntimeError( + "All evaluation splits are empty. Check train_size/valid_size/test_size or dataset configuration." + ) + + empty_splits = [name for name, size in split_sizes.items() if size == 0] + if empty_splits: + console.print( + "[yellow]Warning:[/] Empty evaluation splits detected: " + + ", ".join(empty_splits) + + ". They will be skipped in parity plots." + ) console.log(f"Starting analysis: [cyan]{output_dir}[/]") diff --git a/alphanet/evaler.py b/alphanet/evaler.py index 95b1557..75e5ff2 100644 --- a/alphanet/evaler.py +++ b/alphanet/evaler.py @@ -1,7 +1,6 @@ import torch import os import matplotlib.pyplot as plt -from torch.autograd import grad from tqdm import tqdm from alphanet.models.model import AlphaNetWrapper class Evaluator: @@ -20,17 +19,39 @@ def load_model(self, model_path): self.model.load_state_dict(checkpoint) + @staticmethod + def _warn_skip(split_name, quantity, reason): + print(f"[alpha-eval] Skip {split_name} {quantity}: {reason}.") + + @staticmethod + def _get_diag_limits(pred_batches, target_batches, quantity): + if not pred_batches or not target_batches: + raise RuntimeError( + f"No valid {quantity} samples found across Train/Validation/Test splits." + ) + + combined = torch.cat( + [ + torch.cat(target_batches, dim=0).reshape(-1), + torch.cat(pred_batches, dim=0).reshape(-1), + ], + dim=0, + ) + return combined.min().item(), combined.max().item() + def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=None, disable=False): datasets = {'Train': train_loader, 'Validation': val_loader, 'Test': test_loader} colors = {'Train': '#1f77b4', 'Validation': '#ff7f0e', 'Test': '#2ca02c'} + all_preds_energy = [] + all_targets_energy = [] plt.figure(figsize=(12, 8)) plt.grid(True, linestyle='--', alpha=0.5) for name, loader in datasets.items(): - preds_energy = torch.Tensor([])#.to(self.device) - targets_energy = torch.Tensor([])#.to(self.device) + preds_energy_batches = [] + targets_energy_batches = [] for batch_data in tqdm(loader, disable=disable): batch_data = batch_data.to(self.device) @@ -43,8 +64,19 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No energy, _, _ = model_outputs energy = energy.squeeze() - preds_energy = torch.cat([preds_energy.cpu(), (energy / batch_data.natoms).detach().cpu()], dim=0) - targets_energy = torch.cat([targets_energy.cpu(), batch_data.y.cpu() / batch_data.natoms.cpu()], dim=0) + preds_energy_batches.append((energy / batch_data.natoms).detach().cpu().reshape(-1)) + targets_energy_batches.append((batch_data.y.cpu() / batch_data.natoms.cpu()).reshape(-1)) + + if not preds_energy_batches or not targets_energy_batches: + self._warn_skip(name, "energy parity", "empty dataset") + continue + + preds_energy = torch.cat(preds_energy_batches, dim=0) + targets_energy = torch.cat(targets_energy_batches, dim=0) + + if preds_energy.numel() == 0 or targets_energy.numel() == 0: + self._warn_skip(name, "energy parity", "no collected samples") + continue deviation = torch.abs(preds_energy - targets_energy) threshold = 100 * torch.sqrt(torch.mean((preds_energy - targets_energy) ** 2)).item() @@ -53,6 +85,13 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No mask = mask_deviation & mask_nan preds_energy_filtered = preds_energy[mask] targets_energy_filtered = targets_energy[mask] + + if preds_energy_filtered.numel() == 0 or targets_energy_filtered.numel() == 0: + self._warn_skip(name, "energy parity", "no valid samples after filtering") + continue + + all_preds_energy.append(preds_energy_filtered) + all_targets_energy.append(targets_energy_filtered) energy_mae_filtered = torch.mean(torch.abs(preds_energy_filtered - targets_energy_filtered)).item() energy_rmse_filtered = torch.sqrt(torch.mean((preds_energy_filtered - targets_energy_filtered) ** 2)).item() plt.scatter( @@ -64,8 +103,11 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No s=20 ) - min_energy = targets_energy.min().cpu().numpy() - max_energy = targets_energy.max().cpu().numpy() + min_energy, max_energy = self._get_diag_limits( + all_preds_energy, + all_targets_energy, + "energy", + ) plt.plot([min_energy, max_energy], [min_energy, max_energy], 'k--', lw=2, label='Ideal') plt.xlabel('True Energy per Atom', fontsize=17) plt.ylabel('Predicted Energy per Atom', fontsize=17) @@ -82,13 +124,15 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=None, disable=False): datasets = {'Train': train_loader, 'Validation': val_loader, 'Test': test_loader} colors = {'Train': '#1f77b4', 'Validation': '#ff7f0e', 'Test': '#2ca02c'} + all_preds_force = [] + all_targets_force = [] plt.figure(figsize=(12, 8)) plt.grid(True, linestyle='--', alpha=0.5) for name, loader in datasets.items(): - preds_force = torch.Tensor([])#.to(self.device) - targets_force = torch.Tensor([])#.to(self.device) + preds_force_batches = [] + targets_force_batches = [] for batch_data in tqdm(loader, disable=disable): batch_data = batch_data.to(self.device) @@ -101,18 +145,40 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non _, force,_ = model_outputs if torch.sum(torch.isnan(force)) != 0: - mask = ~torch.isnan(force) + mask = ~torch.isnan(force) & ~torch.isnan(batch_data.force) force = force[mask].reshape((-1, 3)) - batch_data.force = batch_data.force[mask].reshape((-1, 3)) + target_force = batch_data.force[mask].reshape((-1, 3)) + else: + target_force = batch_data.force - preds_force = torch.cat([preds_force.cpu(), force.detach().cpu()], dim=0) - targets_force = torch.cat([targets_force.cpu(), batch_data.force.cpu()], dim=0) + preds_force_batches.append(force.detach().cpu()) + targets_force_batches.append(target_force.cpu()) + + if not preds_force_batches or not targets_force_batches: + self._warn_skip(name, "force parity", "empty dataset") + continue + + preds_force = torch.cat(preds_force_batches, dim=0) + targets_force = torch.cat(targets_force_batches, dim=0) + + if preds_force.numel() == 0 or targets_force.numel() == 0: + self._warn_skip(name, "force parity", "no collected samples") + continue deviation = torch.abs(preds_force - targets_force) threshold = 100 * torch.sqrt(torch.mean((preds_force - targets_force) ** 2)).item() - mask = deviation < threshold + mask_deviation = deviation < threshold + mask_nan = ~torch.isnan(preds_force) & ~torch.isnan(targets_force) + mask = mask_deviation & mask_nan preds_force_filtered = preds_force[mask] targets_force_filtered = targets_force[mask] + + if preds_force_filtered.numel() == 0 or targets_force_filtered.numel() == 0: + self._warn_skip(name, "force parity", "no valid samples after filtering") + continue + + all_preds_force.append(preds_force_filtered.reshape(-1)) + all_targets_force.append(targets_force_filtered.reshape(-1)) force_mae_filtered = torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item() force_rmse_filtered = torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item() @@ -125,8 +191,11 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non s=20 ) - min_force = targets_force.min().cpu().numpy() - max_force = targets_force.max().cpu().numpy() + min_force, max_force = self._get_diag_limits( + all_preds_force, + all_targets_force, + "force", + ) plt.plot([min_force, max_force], [min_force, max_force], 'k--', lw=2, label='Ideal') plt.xlabel('True Force', fontsize=17) @@ -154,4 +223,3 @@ def evaluate(self, data_path, plots_dir=None, disable=False): evaluator = Evaluator(model_path, device) evaluator.evaluate(data_path, plots_dir) - diff --git a/alphanet/infer/calc.py b/alphanet/infer/calc.py index cb68752..adcc555 100644 --- a/alphanet/infer/calc.py +++ b/alphanet/infer/calc.py @@ -4,6 +4,8 @@ from alphanet.models.graph import build_neighbor_topology, graph_from_neighbor_topology from alphanet.models.model import AlphaNetWrapper +torch.set_float32_matmul_precision('high') + class AlphaNetCalculator(Calculator): """ ASE Calculator for AlphaNet models. diff --git a/alphanet/train.py b/alphanet/train.py index 06fd768..63d5351 100644 --- a/alphanet/train.py +++ b/alphanet/train.py @@ -5,6 +5,8 @@ from alphanet.models.model import AlphaNetWrapper from alphanet.mul_trainer import Trainer +torch.set_float32_matmul_precision('high') + def run_training(config1, runtime_config): train_dataset, valid_dataset, test_dataset = get_pic_datasets( @@ -90,4 +92,4 @@ def run_training(config1, runtime_config): if runtime_config["resume"] and ckpt_path_arg: print(f"🔄 Resuming training from checkpoint: {ckpt_path_arg}") - trainer.fit(pl_module, ckpt_path=ckpt_path_arg) \ No newline at end of file + trainer.fit(pl_module, ckpt_path=ckpt_path_arg)