Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions alphanet/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,32 @@ def run_analysis(evaluator: Evaluator, loaders: dict, output_dir: str, quiet: bo

# Ensure output directory exists
Path(output_dir).mkdir(parents=True, exist_ok=True)

split_sizes = {
"Train": len(loaders["train"].dataset),
"Validation": len(loaders["valid"].dataset),
"Test": len(loaders["test"].dataset),
}

split_table = Table(title="Evaluation Splits", show_header=False, box=None)
split_table.add_column("Split", style="dim")
split_table.add_column("Samples")
for split_name, split_size in split_sizes.items():
split_table.add_row(split_name, str(split_size))
console.print(Panel.fit(split_table, border_style="dim"))

if all(size == 0 for size in split_sizes.values()):
raise RuntimeError(
"All evaluation splits are empty. Check train_size/valid_size/test_size or dataset configuration."
)

empty_splits = [name for name, size in split_sizes.items() if size == 0]
if empty_splits:
console.print(
"[yellow]Warning:[/] Empty evaluation splits detected: "
+ ", ".join(empty_splits)
+ ". They will be skipped in parity plots."
)

console.log(f"Starting analysis: [cyan]{output_dir}[/]")

Expand Down
102 changes: 85 additions & 17 deletions alphanet/evaler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import os
import matplotlib.pyplot as plt
from torch.autograd import grad
from tqdm import tqdm
from alphanet.models.model import AlphaNetWrapper
class Evaluator:
Expand All @@ -20,17 +19,39 @@ def load_model(self, model_path):

self.model.load_state_dict(checkpoint)

@staticmethod
def _warn_skip(split_name, quantity, reason):
    """Print a one-line notice that a split is being skipped.

    Args:
        split_name: Label of the split being skipped (interpolated as-is).
        quantity: Name of the quantity/plot that is skipped for that split.
        reason: Short explanation; a trailing period is appended on output.
    """
    # Build the message first so the emitted text is easy to inspect/debug.
    notice = f"[alpha-eval] Skip {split_name} {quantity}: {reason}."
    print(notice)

@staticmethod
def _get_diag_limits(pred_batches, target_batches, quantity):
    """Return the global ``(min, max)`` over all predicted and target values.

    Args:
        pred_batches: List of prediction tensors, concatenated along dim 0.
        target_batches: List of target tensors, concatenated along dim 0.
        quantity: Name of the quantity, used only in the error message.

    Returns:
        Tuple ``(lowest, highest)`` of Python scalars taken over the
        flattened union of targets and predictions.

    Raises:
        RuntimeError: If either list of batches is empty.
    """
    # Nothing was collected for at least one side: there is no range to report.
    if not (pred_batches and target_batches):
        raise RuntimeError(
            f"No valid {quantity} samples found across Train/Validation/Test splits."
        )

    # Flatten each side separately, then merge so one min/max covers both.
    flat_targets = torch.cat(target_batches, dim=0).reshape(-1)
    flat_preds = torch.cat(pred_batches, dim=0).reshape(-1)
    merged = torch.cat([flat_targets, flat_preds], dim=0)

    lowest = merged.min().item()
    highest = merged.max().item()
    return lowest, highest


def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=None, disable=False):
datasets = {'Train': train_loader, 'Validation': val_loader, 'Test': test_loader}
colors = {'Train': '#1f77b4', 'Validation': '#ff7f0e', 'Test': '#2ca02c'}
all_preds_energy = []
all_targets_energy = []

plt.figure(figsize=(12, 8))
plt.grid(True, linestyle='--', alpha=0.5)

for name, loader in datasets.items():
preds_energy = torch.Tensor([])#.to(self.device)
targets_energy = torch.Tensor([])#.to(self.device)
preds_energy_batches = []
targets_energy_batches = []

for batch_data in tqdm(loader, disable=disable):
batch_data = batch_data.to(self.device)
Expand All @@ -43,8 +64,19 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No
energy, _, _ = model_outputs

energy = energy.squeeze()
preds_energy = torch.cat([preds_energy.cpu(), (energy / batch_data.natoms).detach().cpu()], dim=0)
targets_energy = torch.cat([targets_energy.cpu(), batch_data.y.cpu() / batch_data.natoms.cpu()], dim=0)
preds_energy_batches.append((energy / batch_data.natoms).detach().cpu().reshape(-1))
targets_energy_batches.append((batch_data.y.cpu() / batch_data.natoms.cpu()).reshape(-1))

if not preds_energy_batches or not targets_energy_batches:
self._warn_skip(name, "energy parity", "empty dataset")
continue

preds_energy = torch.cat(preds_energy_batches, dim=0)
targets_energy = torch.cat(targets_energy_batches, dim=0)

if preds_energy.numel() == 0 or targets_energy.numel() == 0:
self._warn_skip(name, "energy parity", "no collected samples")
continue

deviation = torch.abs(preds_energy - targets_energy)
threshold = 100 * torch.sqrt(torch.mean((preds_energy - targets_energy) ** 2)).item()
Expand All @@ -53,6 +85,13 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No
mask = mask_deviation & mask_nan
preds_energy_filtered = preds_energy[mask]
targets_energy_filtered = targets_energy[mask]

if preds_energy_filtered.numel() == 0 or targets_energy_filtered.numel() == 0:
self._warn_skip(name, "energy parity", "no valid samples after filtering")
continue

all_preds_energy.append(preds_energy_filtered)
all_targets_energy.append(targets_energy_filtered)
energy_mae_filtered = torch.mean(torch.abs(preds_energy_filtered - targets_energy_filtered)).item()
energy_rmse_filtered = torch.sqrt(torch.mean((preds_energy_filtered - targets_energy_filtered) ** 2)).item()
plt.scatter(
Expand All @@ -64,8 +103,11 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No
s=20
)

min_energy = targets_energy.min().cpu().numpy()
max_energy = targets_energy.max().cpu().numpy()
min_energy, max_energy = self._get_diag_limits(
all_preds_energy,
all_targets_energy,
"energy",
)
plt.plot([min_energy, max_energy], [min_energy, max_energy], 'k--', lw=2, label='Ideal')
plt.xlabel('True Energy per Atom', fontsize=17)
plt.ylabel('Predicted Energy per Atom', fontsize=17)
Expand All @@ -82,13 +124,15 @@ def plot_energy_parity(self, train_loader, val_loader, test_loader, plots_dir=No
def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=None, disable=False):
datasets = {'Train': train_loader, 'Validation': val_loader, 'Test': test_loader}
colors = {'Train': '#1f77b4', 'Validation': '#ff7f0e', 'Test': '#2ca02c'}
all_preds_force = []
all_targets_force = []

plt.figure(figsize=(12, 8))
plt.grid(True, linestyle='--', alpha=0.5)

for name, loader in datasets.items():
preds_force = torch.Tensor([])#.to(self.device)
targets_force = torch.Tensor([])#.to(self.device)
preds_force_batches = []
targets_force_batches = []

for batch_data in tqdm(loader, disable=disable):
batch_data = batch_data.to(self.device)
Expand All @@ -101,18 +145,40 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non
_, force,_ = model_outputs

if torch.sum(torch.isnan(force)) != 0:
mask = ~torch.isnan(force)
mask = ~torch.isnan(force) & ~torch.isnan(batch_data.force)
force = force[mask].reshape((-1, 3))
batch_data.force = batch_data.force[mask].reshape((-1, 3))
target_force = batch_data.force[mask].reshape((-1, 3))
else:
target_force = batch_data.force

preds_force = torch.cat([preds_force.cpu(), force.detach().cpu()], dim=0)
targets_force = torch.cat([targets_force.cpu(), batch_data.force.cpu()], dim=0)
preds_force_batches.append(force.detach().cpu())
targets_force_batches.append(target_force.cpu())

if not preds_force_batches or not targets_force_batches:
self._warn_skip(name, "force parity", "empty dataset")
continue

preds_force = torch.cat(preds_force_batches, dim=0)
targets_force = torch.cat(targets_force_batches, dim=0)

if preds_force.numel() == 0 or targets_force.numel() == 0:
self._warn_skip(name, "force parity", "no collected samples")
continue

deviation = torch.abs(preds_force - targets_force)
threshold = 100 * torch.sqrt(torch.mean((preds_force - targets_force) ** 2)).item()
mask = deviation < threshold
mask_deviation = deviation < threshold
mask_nan = ~torch.isnan(preds_force) & ~torch.isnan(targets_force)
mask = mask_deviation & mask_nan
preds_force_filtered = preds_force[mask]
targets_force_filtered = targets_force[mask]

if preds_force_filtered.numel() == 0 or targets_force_filtered.numel() == 0:
self._warn_skip(name, "force parity", "no valid samples after filtering")
continue

all_preds_force.append(preds_force_filtered.reshape(-1))
all_targets_force.append(targets_force_filtered.reshape(-1))
force_mae_filtered = torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item()
force_rmse_filtered = torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item()

Expand All @@ -125,8 +191,11 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non
s=20
)

min_force = targets_force.min().cpu().numpy()
max_force = targets_force.max().cpu().numpy()
min_force, max_force = self._get_diag_limits(
all_preds_force,
all_targets_force,
"force",
)
plt.plot([min_force, max_force], [min_force, max_force], 'k--', lw=2, label='Ideal')

plt.xlabel('True Force', fontsize=17)
Expand Down Expand Up @@ -154,4 +223,3 @@ def evaluate(self, data_path, plots_dir=None, disable=False):

evaluator = Evaluator(model_path, device)
evaluator.evaluate(data_path, plots_dir)

2 changes: 2 additions & 0 deletions alphanet/infer/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from alphanet.models.graph import build_neighbor_topology, graph_from_neighbor_topology
from alphanet.models.model import AlphaNetWrapper

torch.set_float32_matmul_precision('high')

class AlphaNetCalculator(Calculator):
"""
ASE Calculator for AlphaNet models.
Expand Down
4 changes: 3 additions & 1 deletion alphanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from alphanet.models.model import AlphaNetWrapper
from alphanet.mul_trainer import Trainer

torch.set_float32_matmul_precision('high')

def run_training(config1, runtime_config):

train_dataset, valid_dataset, test_dataset = get_pic_datasets(
Expand Down Expand Up @@ -90,4 +92,4 @@ def run_training(config1, runtime_config):
if runtime_config["resume"] and ckpt_path_arg:
print(f"🔄 Resuming training from checkpoint: {ckpt_path_arg}")

trainer.fit(pl_module, ckpt_path=ckpt_path_arg)
trainer.fit(pl_module, ckpt_path=ckpt_path_arg)