diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bfa85c9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# --- .gitignore 内容开始 --- + +# 1. 忽略 LAMMPS 日志和轨迹文件 +log.lammps +*.log +*.lammpstrj + +# 2. 忽略数据结构文件 +*.data +POSCAR +CH4.txt +*pt +*.pt +*.pkl +# 3. 忽略大模型权重文件 (通常不建议传大文件到 git,除非你需要) + +# 4. 忽略 Python 编译缓存和打包文件 +__pycache__/ +*.egg-info/ +build/ +dist/ +*.zip +lammps/ +# --- .gitignore 内容结束 --- diff --git a/README.md b/README.md index a32fea3..0bb0e6e 100644 --- a/README.md +++ b/README.md @@ -6,17 +6,14 @@ We present **AlphaNet**, a local frame-based equivariant model designed to tackle the challenges of achieving both accurate and efficient simulations for atomistic systems. **AlphaNet** enhances computational efficiency and accuracy by leveraging the local geometric structures of atomic environments through the construction of equivariant local frames and learnable frame transitions. And inspired by Quantum Mechanics, AlphaNet **introduces efficient multi-body message passing by using contraction of matrix product states** rather than common 2-body message passing. Notably, AlphaNet offers one of the best trade-offs between computational efficiency and accuracy among existing models. Moreover, AlphaNet exhibits scalability across a broad spectrum of system and dataset sizes, affirming its versatility. markdown -## Update Log (v0.1.2) +## Update Log (v0.1.2-beta) ### Major Changes -1. **Added new 2 pretrained models** - - Provide a pretrained model for materials: **AlphaNet-MATPES-r2scan** and our first pretrained model for catlysis: **AlphaNet-AQCAT25**, see them in the [pretrained](./pretrained) folder. - - Users can **convert the checkpoint trained in torch to our JAX model** - -2. **Fixed some bugs** - - Support non-periodic boundary conditions in our ase calculator. - - Fixed errors in float64 +1. **Add lammps mliap interface** +2. **Slight change of model arch** +3. **Add finetune option** + ## Installation Guide @@ -84,7 +81,11 @@ alpha-train example.json # use --help to see more functions, like multi-gpu trai ```bash alpha-conv -i in.ckpt -o out.ckpt # use --help to see more functions ``` -3. Evaluate a model and draw diagonal plot: +2. Finetune a converted ckpt: +```bash +alpha-train example.json --finetune /path/to/your.ckpt +``` +4. Evaluate a model and draw diagonal plot: ```bash alpha-eval -c example.json -m /path/to/ckpt # use --help to see more functions ``` @@ -142,67 +143,17 @@ print(atoms.get_potential_energy()) ``` -### Using AlphaNet in JAX -1. Installation - ```bash - pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ - ``` - This is just for reference. JAX installation may be tricky, please get more information in [JAX](https://docs.jax.dev/en/latest/installation.html) and its github issues. - - Currently I suggest **version>=0.4 <=0.4.10 or >=0.4.30 <=0.5 or ==0.6.2** - - Install flax and haiku - ```bash - pip install matscipy - pip install flax - pip install -U dm-haiku - ``` - -2. Converted checkpoints: - - See pretrained directory - -3. Convert a self-trained ckpt - - First from torch to flax: - ```bash - python scripts/conv_pt2flax.py #need to modify the path in it. - ``` - Then from flax to haiku: - - ```bash - python scripts/flax2haiku.py #need to modify the path in it. - ``` - -4. Performance: - - The output (energy forces stress) difference from torch model would below 0.001. I ran speed tests on a 4090 GPU, system size from 4 to 300, and get a **2.5x to 3x** speed up. - - Please note jax model need to be compiled first, so the first run could take a few seconds or minutes, but would be pretty fast after that. - -## Dataset Download - -[The Defected Bilayer Graphene Dataset](https://zenodo.org/records/10374206) - -[The Formate Decomposition on Cu Dataset](https://archive.materialscloud.org/record/2022.45) - -[The Zeolite Dataset](https://doi.org/10.6084/m9.figshare.27800211) - -[The OC dataset](https://opencatalystproject.org/) - -[The MPtrj dataset](https://matbench-discovery.materialsproject.org/data) - ## Pretrained Models -Current pretrained models: +Current pretrained models (due to the arch changes, previous pretrained models would need update, which will be done asap): For materials: -- [AlphaNet-MPtrj-v1](pretrained/MPtrj): A model trained on the MpTrj dataset. -- [AlphaNet-oma-v1](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj. -- [AlphaNet-MATPES-r2scan](pretrained/MATPES): A model trained on the MATPES-r2scan dataset. -For surfaces adsorbtion and reactions: -- [AlphaNet-AQCAT25](pretrained/AQCAT25): A model trained on the AQCAT25 dataset. +- [AlphaNet-oma-v1.5](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj. + +## Use AlphaNet in LAMMPS + +See [mliap_lammps](mliap_lammps.md) ## License @@ -222,3 +173,6 @@ We thank all contributors and the community for their support. Please open an is + + + diff --git a/alphanet/cli.py b/alphanet/cli.py index 22f0253..d0d3cbd 100644 --- a/alphanet/cli.py +++ b/alphanet/cli.py @@ -57,7 +57,8 @@ def display_config_table(main_config, runtime_config): @click.option("--num_devices", type=int, default=1, help="GPUs per node") @click.option("--resume", is_flag=True, help="Resume training from checkpoint") @click.option("--ckpt_path", type=click.Path(), default=None, help="Path to checkpoint file") -def main(config, num_nodes, num_devices, resume, ckpt_path): +@click.option("--finetune", type=click.Path(exists=True), default=None, help="Path to pretrained checkpoint for finetuning (resets optimizer)") +def main(config, num_nodes, num_devices, resume, ckpt_path, finetune): with open(config, "r") as f: mconfig = json.load(f) @@ -67,7 +68,8 @@ def main(config, num_nodes, num_devices, resume, ckpt_path): "num_nodes": num_nodes, "num_devices": num_devices, "resume": resume, - "ckpt_path": ckpt_path + "ckpt_path": ckpt_path, + "finetune_path": finetune } display_header() diff --git a/alphanet/config.py b/alphanet/config.py index 5bf9354..5fd8821 100644 --- a/alphanet/config.py +++ b/alphanet/config.py @@ -2,13 +2,16 @@ import subprocess import json -import torch +#import torch from typing import Literal, Dict, Optional from pydantic_settings import BaseSettings try: VERSION = ( - subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() + subprocess.check_output( + ["git", "rev-parse", "HEAD"], + stderr=subprocess.DEVNULL, + ).decode().strip() ) except Exception: VERSION = "NA" @@ -22,6 +25,7 @@ class TrainConfig(BaseSettings): batch_size: int = 32 vt_batch_size: int = 32 lr: float = 0.0005 + optimizer: str = "radam" lr_decay_factor: float = 0.5 lr_decay_step_size: int = 150 weight_decay: float = 0 @@ -86,7 +90,13 @@ class AlphaConfig(BaseSettings): has_norm_after_flag: bool = False reduce_mode: str = "sum" zbl: bool = False - device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu") + zbl_w: Optional[list] = [0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715] + zbl_b: Optional[list] = [3.20,1.10,0.102,0.958,1.28,1.14,1.69,5] + zbl_gamma: float = 1.001 + zbl_alpha: float = 0.6032 + zbl_E2: float = 14.399645478425 + zbl_A0: float = 0.529177210903 + device: str = "cuda" diff --git a/alphanet/create_lammps_model.py b/alphanet/create_lammps_model.py new file mode 100644 index 0000000..7d29be3 --- /dev/null +++ b/alphanet/create_lammps_model.py @@ -0,0 +1,91 @@ +import argparse +import os +import torch +from pathlib import Path + +# Import the AlphaNet model wrapper and config +from alphanet.models.model import AlphaNetWrapper +from alphanet.config import All_Config + +# Import the Python-level LAMMPS interface class +try: + from alphanet.infer.lammps_mliap_alphanet import LAMMPS_MLIAP_ALPHANET +except ImportError: + print("Could not import LAMMPS_MLIAP_ALPHANET.") + print("Please ensure 'alphanet/infer/lammps_mliap_alphanet.py' exists.") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an AlphaNet model to LAMMPS ML-IAP format (Python Pickle)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", "-c", required=True, type=str, + help="Path to the model configuration JSON file", + ) + parser.add_argument( + "--checkpoint", "-m", required=True, type=str, + help="Path to the trained model checkpoint (.ckpt)", + ) + parser.add_argument( + "--output", "-o", required=True, type=str, + help="Output path to save the model (e.g., alphanet_lammps.pt)", + ) + parser.add_argument( + "--device", type=str, default="cpu", + help="Device to load the model on ('cpu' or 'cuda')", + ) + parser.add_argument( + "--dtype", type=str, default="float64", + choices=["float32", "float64"], + help="Data type for the model", + ) + return parser.parse_args() + +def main(): + args = parse_args() + + device = torch.device(args.device) + + print(f"1. Loading configuration from {args.config}...") + config_obj = All_Config().from_json(args.config) + + config_obj.model.dtype = "64" if args.dtype == "float64" else "32" + + print(f"2. Initializing AlphaNetWrapper (precision: {args.dtype}, device: {args.device})...") + model_wrapper = AlphaNetWrapper(config_obj.model) + + print(f"3. Loading weights from {args.checkpoint}...") + ckpt = torch.load(args.checkpoint, map_location=device) + + if 'state_dict' in ckpt: + state_dict = {k.replace('model.', ''): v for k, v in ckpt['state_dict'].items()} + model_wrapper.model.load_state_dict(state_dict, strict=False) + else: + model_wrapper.load_state_dict(ckpt, strict=False) + + if args.dtype == "float64": + model_wrapper.double() + else: + model_wrapper.float() + + model_wrapper.to(device).eval() + + print("4. Creating LAMMPS ML-IAP Interface Object...") + lammps_interface_object = LAMMPS_MLIAP_ALPHANET(model_wrapper) + + if device.type == 'cuda': + lammps_interface_object.model.cuda() + + print(f"5. Saving Python object to {args.output}...") + # Using standard torch.save for Python pickle compatibility + torch.save(lammps_interface_object, args.output) + + print("\n--- Success ---") + print(f"Created LAMMPS model file: {args.output}") + print("Usage in LAMMPS: pair_style mliap model/python ...") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/alphanet/data.py b/alphanet/data.py index 9ccaae3..6b80cd9 100644 --- a/alphanet/data.py +++ b/alphanet/data.py @@ -4,6 +4,7 @@ from tqdm import tqdm import torch from sklearn.utils import shuffle +from sklearn.model_selection import train_test_split import joblib from torch_geometric.data import Data, DataLoader, InMemoryDataset, download_url, extract_zip @@ -26,7 +27,7 @@ def get_pic_datasets(root, name, config): test_dataset = test_dataset[test_indices] else: - dataset = CustomPickleDataset(name=name, root=root) + dataset = CustomPickleDataset(name=name, root=root, config=config) split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=config.train_size, valid_size=config.valid_size, test_size=config.test_size, seed=config.seed) @@ -120,8 +121,9 @@ def process(self): print('Saving...') torch.save((data, slices), self.processed_paths[0]) - def get_idx_split(self, data_size, train_size=None, valid_size=None, seed=None): - ids = shuffle(list(range(data_size))) + #def get_idx_split(self, data_size, train_size=None, valid_size=None, seed=None): + def get_idx_split(self, data_size, train_size=None, valid_size=None, test_size=None, seed=None): + ids = shuffle(list(range(data_size)), random_state=seed) if train_size is not None and valid_size is None: train_idx = ids[:train_size] diff --git a/alphanet/evaler.py b/alphanet/evaler.py index 7dc10fa..95b1557 100644 --- a/alphanet/evaler.py +++ b/alphanet/evaler.py @@ -113,8 +113,8 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non mask = deviation < threshold preds_force_filtered = preds_force[mask] targets_force_filtered = targets_force[mask] - force_mae_filtered = 0.5*torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item() - force_rmse_filtered = 0.5*torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item() + force_mae_filtered = torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item() + force_rmse_filtered = torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item() plt.scatter( targets_force_filtered.cpu().numpy(), diff --git a/alphanet/infer/calc.py b/alphanet/infer/calc.py index 35def11..cb68752 100644 --- a/alphanet/infer/calc.py +++ b/alphanet/infer/calc.py @@ -1,7 +1,7 @@ -import torch import numpy as np +import torch from ase.calculators.calculator import Calculator, all_changes -from ase.data import atomic_numbers +from alphanet.models.graph import build_neighbor_topology, graph_from_neighbor_topology from alphanet.models.model import AlphaNetWrapper class AlphaNetCalculator(Calculator): @@ -15,7 +15,16 @@ class AlphaNetCalculator(Calculator): """ implemented_properties = ['energy', 'free_energy', 'forces', 'stress'] - def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs): + def __init__( + self, + ckpt_path, + config, + device='cpu', + precision='32', + reuse_neighbors=True, + skin=0.5, + **kwargs, + ): """ Initializes the AlphaNetCalculator. @@ -24,11 +33,15 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs): config (object): Model configuration object. device (str): Device to run the model on ('cpu' or 'cuda'). precision (str): Precision for calculations ('32' for float, '64' for double). + reuse_neighbors (bool): Whether to cache and reuse the neighbor list. + skin (float): Skin distance (Angstrom) for neighbor list caching. **kwargs: Additional arguments for the base ASE Calculator. """ Calculator.__init__(self, **kwargs) # --- Model Loading --- + if precision == "64": + config.dtype = '64' if ckpt_path.endswith('ckpt'): self.model = AlphaNetWrapper(config).to(torch.device(device)) # Load state dict, ignoring mismatches if any @@ -42,11 +55,101 @@ def __init__(self, ckpt_path, config, device='cpu', precision='32', **kwargs): self.precision = torch.float32 if precision == "32" else torch.float64 if precision == "64": - self.model.double() + self.model.double() self.model.eval() # Set model to evaluation mode self.model.to(self.device) self.config = config + self.reuse_neighbors = reuse_neighbors + self.supports_neighbor_cache = hasattr(self.model, "forward_graph") + self.skin = max(float(skin), 0.0) + self._neighbor_topology = None + self._reference_positions = None + self._reference_cell = None + self._reference_numbers = None + self._reference_pbc = None + self._neighbor_cache_stats = {"rebuilds": 0, "reuses": 0} + + @property + def neighbor_cache_stats(self): + return dict(self._neighbor_cache_stats) + + def reset_neighbor_cache(self): + self._neighbor_topology = None + self._reference_positions = None + self._reference_cell = None + self._reference_numbers = None + self._reference_pbc = None + + def _prepare_atoms(self): + if not self.atoms.pbc.any(): + print("Non-periodic system detected. Automatically adding a large vacuum box for calculation.") + calc_atoms = self.atoms.copy() + padding = 20.0 + new_cell_dims = calc_atoms.get_positions().ptp(axis=0) + padding + calc_atoms.set_cell(np.diag(new_cell_dims)) + calc_atoms.center() + calc_atoms.pbc = True + return calc_atoms + return self.atoms + + def _max_displacement_since_rebuild(self, positions, cell, pbc): + if self._reference_positions is None or self._reference_cell is None: + return float("inf") + aligned_positions = self._align_positions_to_reference(positions, cell, pbc) + delta_cart = aligned_positions - self._reference_positions + distances = np.linalg.norm(delta_cart, axis=1) + return float(np.max(distances)) if distances.size else 0.0 + + def _align_positions_to_reference(self, positions, cell, pbc): + if ( + self._reference_positions is None + or self._reference_cell is None + or not np.allclose(cell, self._reference_cell, atol=1e-12, rtol=0.0) + ): + return positions.copy() + inverse_cell = np.linalg.inv(cell) + current_frac = positions @ inverse_cell + reference_frac = self._reference_positions @ inverse_cell + delta_frac = current_frac - reference_frac + periodic_axes = np.asarray(pbc, dtype=bool) + delta_frac[:, periodic_axes] -= np.round(delta_frac[:, periodic_axes]) + return (reference_frac + delta_frac) @ cell + + def _should_rebuild_topology(self, positions, cell, numbers, pbc): + if not self.reuse_neighbors or self.skin <= 0.0: + return True + if self._neighbor_topology is None: + return True + if self._reference_numbers is None or numbers.shape != self._reference_numbers.shape: + return True + if not np.array_equal(numbers, self._reference_numbers): + return True + if self._reference_pbc is None or not np.array_equal(np.asarray(pbc, dtype=bool), self._reference_pbc): + return True + if self._reference_cell is None or not np.allclose(cell, self._reference_cell, atol=1e-12, rtol=0.0): + return True + return self._max_displacement_since_rebuild(positions, cell, pbc) > (0.5 * self.skin) + + def _build_or_reuse_topology(self, positions, cell_array, numbers, pbc, pos, natoms): + if self._should_rebuild_topology(positions, cell_array, numbers, pbc): + self._neighbor_topology = build_neighbor_topology( + pos=pos.detach(), + natoms=natoms, + cell=torch.tensor(cell_array, dtype=self.precision, device=self.device).detach(), + cutoff=self.config.cutoff, + skin=self.skin, + precision=self.precision, + numbers=numbers, + ) + self._reference_positions = positions.copy() + self._reference_cell = cell_array.copy() + self._reference_numbers = numbers.copy() + self._reference_pbc = np.asarray(pbc, dtype=bool).copy() + self._neighbor_cache_stats["rebuilds"] += 1 + else: + self._neighbor_cache_stats["reuses"] += 1 + return self._neighbor_topology def calculate(self, atoms=None, properties=None, system_changes=all_changes): """ @@ -59,43 +162,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): """ Calculator.calculate(self, atoms, properties, system_changes) properties = properties or ['energy'] - - # --- Handle Non-Periodic Systems --- - # If the system is not periodic (e.g., a molecule), we create a copy and - # place it in a large box with vacuum padding. This allows the model, - # which assumes periodicity, to treat it as an isolated system. - if not self.atoms.pbc.any(): - print("Non-periodic system detected. Automatically adding a large vacuum box for calculation.") - calc_atoms = self.atoms.copy() - # Add 20 Å of vacuum padding around the molecule - padding = 20.0 - new_cell_dims = calc_atoms.get_positions().ptp(axis=0) + padding - calc_atoms.set_cell(np.diag(new_cell_dims)) - calc_atoms.center() - calc_atoms.pbc = True # Treat it as periodic now - else: - calc_atoms = self.atoms + calc_atoms = self._prepare_atoms() + needs_stress = 'stress' in properties + needs_forces = needs_stress or ('forces' in properties) + grad_enabled = needs_forces or needs_stress # --- Prepare Tensors for the Model --- - z = torch.tensor( - [atomic_numbers[atom.symbol] for atom in calc_atoms], - dtype=torch.long, - device=self.device - ) - pos = torch.tensor( - calc_atoms.get_positions(), - dtype=self.precision, - device=self.device, - requires_grad=(self.config.compute_forces) - ) - - # Cell should only be provided if the system is periodic - cell = torch.tensor( - calc_atoms.get_cell(complete=True), - dtype=self.precision, - device=self.device - ) if calc_atoms.pbc.any() else None - + atomic_numbers = calc_atoms.get_atomic_numbers() + wrapped_positions = calc_atoms.get_positions(wrap=True) + cell_array = np.array(calc_atoms.get_cell(complete=True)) + z = torch.tensor(atomic_numbers, dtype=torch.long, device=self.device) natoms = torch.tensor( [len(calc_atoms)], dtype=torch.int64, @@ -104,8 +180,67 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): batch = torch.zeros_like(z).to(self.device) # --- Run Model Inference --- - with torch.set_grad_enabled(self.config.compute_forces): - energy, forces, stress = self.model(pos, z, batch, natoms, cell, "infer") + use_neighbor_cache = ( + self.reuse_neighbors + and self.supports_neighbor_cache + and calc_atoms.pbc.any() + and not needs_stress + ) + positions_for_model = wrapped_positions + if use_neighbor_cache: + positions_for_model = self._align_positions_to_reference( + wrapped_positions, + cell_array, + calc_atoms.pbc, + ) + pos = torch.tensor( + positions_for_model, + dtype=self.precision, + device=self.device, + requires_grad=grad_enabled, + ) + cell = torch.tensor( + cell_array, + dtype=self.precision, + device=self.device + ) if calc_atoms.pbc.any() else None + with torch.set_grad_enabled(grad_enabled): + if use_neighbor_cache: + topology = self._build_or_reuse_topology( + positions=positions_for_model, + cell_array=cell_array, + numbers=atomic_numbers, + pbc=calc_atoms.pbc, + pos=pos, + natoms=natoms, + ) + graph_data = graph_from_neighbor_topology( + pos=pos, + z=z, + natoms=natoms, + batch=batch, + topology=topology, + cell=cell, + cutoff=self.config.cutoff, + dtype=self.precision, + ) + energy, forces, stress = self.model.forward_graph( + graph_data, + prefix="infer", + compute_forces=needs_forces, + compute_stress=False, + ) + else: + energy, forces, stress = self.model( + pos, + z, + batch, + natoms, + cell, + "infer", + compute_forces=needs_forces, + compute_stress=needs_stress, + ) # --- Store Results --- self.results['energy'] = energy.detach().cpu().item() @@ -125,5 +260,3 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): stress_matrix[0, 2], # xz stress_matrix[0, 1] # xy ]) - - diff --git a/alphanet/infer/lammps_mliap_alphanet.py b/alphanet/infer/lammps_mliap_alphanet.py new file mode 100644 index 0000000..073e0fa --- /dev/null +++ b/alphanet/infer/lammps_mliap_alphanet.py @@ -0,0 +1,431 @@ +import logging +import math +from typing import Dict, Tuple, Optional +from math import pi + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch_scatter import scatter +from ase.data import chemical_symbols + +# Import custom model modules +from alphanet.models.alphanet import AlphaNet +from alphanet.models.model import AlphaNetWrapper +# from alphanet.models.graph import GraphData, get_max_neighbors_mask # Uncomment if needed + +# Try to import LAMMPS interface +try: + from lammps.mliap.mliap_unified_abc import MLIAPUnified +except ImportError: + class MLIAPUnified: + def __init__(self): pass + print("Warning: LAMMPS MLIAP-Unified interface not found. Creating dummy class.") + + +class LAMMPS_MP(torch.autograd.Function): + """ + Handles MPI communication for gradients between Local and Ghost atoms within LAMMPS. + """ + @staticmethod + def forward(ctx, *args): + feats, data = args + ctx.vec_len = feats.shape[-1] + ctx.data = data + + out = torch.empty_like(feats) + if not feats.is_contiguous(): + feats = feats.contiguous() + + # Forward exchange: Ghost atoms get data from their Local owners on other procs + data.forward_exchange(feats, out, ctx.vec_len) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + (grad,) = grad_outputs + + gout = grad.clone() + if not gout.is_contiguous(): + gout = gout.contiguous() + + # Reverse exchange: Sum gradients from Ghost atoms back to their Local owners + ctx.data.reverse_exchange(grad, gout, ctx.vec_len) + + return gout, None + + +class AlphaNetEdgeForcesWrapper(torch.nn.Module): + """ + Wrapper for AlphaNet to compute forces via edge gradients directly. + """ + def __init__(self, model: AlphaNet): + super().__init__() + + # Copy submodules + self.z_emb = model.z_emb + self.z_emb_ln = model.z_emb_ln + self.radial_emb = model.radial_emb + self.radial_lin = model.radial_lin + self.neighbor_emb = model.neighbor_emb + self.S_vector = model.S_vector + self.lin = model.lin + self.message_layers = model.message_layers + self.FTEs = model.FTEs + self.last_layer = model.last_layer + self.last_layer_quantum = model.last_layer_quantum + + # Copy Parameters + self.a = model.a + self.b = model.b + self.kernel1 = model.kernel1 + + if isinstance(model.kernels_real, torch.Tensor): + self.kernels_real = nn.ParameterList([nn.Parameter(p) for p in model.kernels_real]) + else: + self.kernels_real = model.kernels_real + + if isinstance(model.kernels_imag, torch.Tensor): + self.kernels_imag = nn.ParameterList([nn.Parameter(p) for p in model.kernels_imag]) + else: + self.kernels_imag = model.kernels_imag + + # Metadata + self.cutoff = model.cutoff + self.pi = model.pi + self.eps = 1e-9 + self.hidden_channels = model.hidden_channels + self.chi1 = model.chi1 + self.complex_type = model.complex_type + self.rcutfac = float(model.cutoff) + + self.register_buffer("atomic_numbers", torch.arange(1, 95)) + self.eval() + + def handle_lammps(self, tensor: Tensor, lammps_class: Optional[object], natoms: Tensor) -> Tensor: + """ + Syncs tensor data for ghost atoms if running inside LAMMPS with MPI. + """ + if lammps_class is None: + return tensor + + n_local = int(natoms[0]) + n_total = int(natoms[1]) + n_current = tensor.size(0) + current_dim = tensor.size(1) + + # Pad if necessary + if n_current == n_local and n_total > n_local: + padding = torch.zeros((n_total - n_local, current_dim), dtype=tensor.dtype, device=tensor.device) + tensor_full = torch.cat([tensor, padding], dim=0) + else: + tensor_full = tensor + + if n_total > n_local: + if tensor_full.device != self.a.device: + tensor_full = tensor_full.to(self.a.device) + + orig_dtype = tensor_full.dtype + if orig_dtype != torch.float64: + tensor_full = tensor_full.to(torch.float64) + + target_dim = getattr(lammps_class, "ndescriptors", current_dim) + + # Pad dimension if needed + if current_dim < target_dim: + pad_width = target_dim - current_dim + col_padding = torch.zeros((tensor_full.size(0), pad_width), dtype=tensor_full.dtype, device=tensor_full.device) + tensor_ready = torch.cat([tensor_full, col_padding], dim=1) + else: + tensor_ready = tensor_full + + if not tensor_ready.is_contiguous(): + tensor_ready = tensor_ready.contiguous() + + if tensor_ready.is_cuda: + torch.cuda.synchronize() + + # Sync data using custom autograd function + tensor_synced = LAMMPS_MP.apply(tensor_ready, lammps_class) + + if current_dim < target_dim: + tensor_synced = tensor_synced[:, :current_dim] + if tensor_synced.dtype != orig_dtype: + tensor_synced = tensor_synced.to(orig_dtype) + return tensor_synced + + return tensor_full + + def forward(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor]: + pos = data["positions"] + z = data["node_attrs"] + edge_index = data["edge_index"] + edge_vec = data["vectors"] + + # Ensure vectors require grad for force computation + if not edge_vec.requires_grad: + edge_vec = edge_vec.requires_grad_() + + lammps_ptr = data.get("lammps_ptr") + natoms_info = data["natoms"] + + n_local = int(natoms_info[0]) + n_total = pos.size(0) + + dist = torch.linalg.norm(edge_vec, dim=1) + z_emb = self.z_emb_ln(self.z_emb(z)) + radial_emb = self.radial_emb(dist) + radial_hidden = self.radial_lin(radial_emb) + rbounds = 0.5 * (torch.cos(dist * self.pi / self.cutoff) + 1.0) + radial_hidden = rbounds.unsqueeze(-1) * radial_hidden + + s = self.neighbor_emb(z, z_emb, edge_index, radial_hidden) + + vec = torch.zeros(n_local, 3, s.size(1), device=s.device, dtype=s.dtype) + s = s[:n_local] + + j = edge_index[0] + i = edge_index[1] + edge_diff = edge_vec / (dist.unsqueeze(1) + self.eps) + + edge_vec_mean = scatter(edge_vec, i, reduce='mean', dim=0, dim_size=n_total) + edge_cross = torch.cross(edge_vec, edge_vec_mean[i]) + edge_vertical = torch.cross(edge_diff, edge_cross) + edge_frame = torch.cat((edge_diff.unsqueeze(-1), edge_cross.unsqueeze(-1), edge_vertical.unsqueeze(-1)), dim=-1) + + # Sync initial s features + s_full = self.handle_lammps(s, lammps_ptr, natoms_info) + + S_i_j = self.S_vector(s_full, edge_diff.unsqueeze(-1), edge_index, radial_hidden)[:n_local] + sij_flat = S_i_j.reshape(n_local, -1) + sij_full_flat = self.handle_lammps(sij_flat, lammps_ptr, natoms_info) + S_i_j = sij_full_flat.reshape(n_total, 3, -1) + + scalrization1 = torch.sum(S_i_j[i].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) + scalrization2 = torch.sum(S_i_j[j].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) + scalrization1[:, 1, :] = torch.square(scalrization1[:, 1, :].clone()) + scalrization2[:, 1, :] = torch.square(scalrization2[:, 1, :].clone()) + + scalar3 = (self.lin(torch.permute(scalrization1, (0, 2, 1))) + + torch.permute(scalrization1, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) + scalar4 = (self.lin(torch.permute(scalrization2, (0, 2, 1))) + + torch.permute(scalrization2, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) + + edge_weight = torch.cat((scalar3, scalar4), dim=-1) * rbounds.unsqueeze(-1) + edge_weight = torch.cat((edge_weight, radial_hidden, radial_emb), dim=-1) + + quantum = torch.einsum('ik,bi->bk', self.kernel1, z_emb[:n_local]) + real, imagine = torch.split(quantum, self.chi1, dim=-1) + quantum = torch.complex(real, imagine) + + rope: Optional[Tensor] = None + + # Message Passing Layers + for i_layer, (message_layer, fte, kernel_real, kernel_imag) in enumerate(zip( + self.message_layers, self.FTEs, self.kernels_real, self.kernels_imag + )): + # Sync s + s_full = self.handle_lammps(s, lammps_ptr, natoms_info) + + # Sync vec + vec_flat = vec.reshape(n_local, -1) + vec_full_flat = self.handle_lammps(vec_flat, lammps_ptr, natoms_info) + vec_full = vec_full_flat.reshape(n_total, 3, -1) + + # Sync rope + if rope is not None: + if rope.is_complex(): + rope_real = torch.view_as_real(rope).flatten(1) + rope_full_real = self.handle_lammps(rope_real, lammps_ptr, natoms_info) + rope_full = torch.view_as_complex(rope_full_real.view(n_total, -1, 2)) + else: + rope_full = self.handle_lammps(rope, lammps_ptr, natoms_info) + else: + rope_full = None + + new_rope_full, ds_full, dvec_full = message_layer(s_full, vec_full, edge_index, radial_emb, edge_weight, edge_diff, rope_full) + + rope = new_rope_full[:n_local] + ds = ds_full[:n_local] + dvec = dvec_full[:n_local] + s = s + ds + vec = vec + dvec + + kerneli = torch.complex(kernel_real, kernel_imag) + quantum = torch.einsum('ikl,bi,bl->bk', kerneli, s.to(self.complex_type), quantum) + quantum = quantum / (self.eps + quantum.abs().to(self.complex_type)) + ds_fte, dvec_fte = fte(s, vec) + s = s + ds_fte + vec = vec + dvec_fte + + # Final Readout + s_per_atom = self.last_layer(s) + self.last_layer_quantum(torch.cat([quantum.real, quantum.imag], dim=-1)) / self.chi1 + node_energy = (self.a[z[:n_local]].unsqueeze(1) * s_per_atom + self.b[z[:n_local]].unsqueeze(1)).squeeze(-1) + + if node_energy.shape[0] > n_local: + s_total = node_energy[:n_local].sum() + else: + s_total = node_energy.sum() + + # Compute Gradients (Force = -dE/dr, computed here as gradients w.r.t edge vectors) + if s_total.grad_fn is not None: + grads = torch.autograd.grad( + outputs=[s_total], + inputs=[edge_vec], + grad_outputs=[torch.ones_like(s_total)], + retain_graph=False, + create_graph=False, + allow_unused=True, + )[0] + if grads is None: + grads = torch.zeros_like(edge_vec) + else: + grads = torch.zeros_like(edge_vec) + + pair_forces = grads + + return s_total, node_energy, pair_forces + + +class LAMMPS_MLIAP_ALPHANET(MLIAPUnified): + def __init__(self, model_wrapper: AlphaNetWrapper, **kwargs): + super().__init__() + + internal_model = model_wrapper.model + internal_model.double() + internal_model.eval() + edge_wrapper = AlphaNetEdgeForcesWrapper(internal_model).eval() + edge_wrapper.double() + + self.model = edge_wrapper + self.element_types = [chemical_symbols[i] for i in range(1, 95)] + self.num_species = 94 + self.rcutfac = 0.5 * float(model_wrapper.model.cutoff) + self.hidden_dim = internal_model.hidden_channels + + target_dim = 3 * self.hidden_dim + self.ndescriptors = target_dim + self.nparams = target_dim + self.dtype = model_wrapper.precision + self.device = "cpu" + self.initialized = False + self.step = 0 + + def _initialize_device(self, data): + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + self.model = self.model.to(self.device) + self.model.eval() + self.initialized = True + + def compute_forces(self, data): + natoms = data.nlocal + ntotal = data.ntotal + nghosts = ntotal - natoms + npairs = data.npairs + + if not self.initialized: + self.model.to(self.dtype) + self._initialize_device(data) + + if hasattr(data.elems, 'get'): + elems_np = data.elems.get() + else: + elems_np = data.elems + + species = torch.as_tensor(elems_np, dtype=torch.int64, device=self.device) + self.step += 1 + + if natoms == 0 or npairs <= 1: + return + + # 1. Prepare data batch + batch = self._prepare_batch(data, natoms, nghosts, species) + + # 2. Forward pass + batch["vectors"].requires_grad_(True) + total_energy, atom_energies, pair_forces = self.model(batch) + + if self.device.type == "cuda": + torch.cuda.synchronize() + + # 3. Update LAMMPS with energies and forces + self._update_lammps_data(data, atom_energies, pair_forces, natoms, total_energy) + + def _update_lammps_data(self, data, atom_energies, pair_forces, natoms, total_energy): + if self.dtype == torch.float32: + pair_forces = pair_forces.double() + atom_energies = atom_energies.double() + total_energy = total_energy.double() + + atom_energies_cpu = atom_energies[:natoms].detach().cpu().numpy() + + # Update atom energies + if hasattr(data.eatoms, 'get'): + try: data.eatoms[:natoms] = atom_energies_cpu + except: pass + else: + try: + eatoms_tensor = torch.as_tensor(data.eatoms) + eatoms_tensor[:natoms].copy_(torch.from_numpy(atom_energies_cpu)) + except: + pass + + data.energy = total_energy.item() + + final_forces = pair_forces.detach() + + if not final_forces.is_contiguous(): + final_forces = final_forces.contiguous() + + # Update pair forces (GPU or CPU) + if self.device.type == 'cuda': + try: + from torch.utils.dlpack import to_dlpack + from cupy import from_dlpack + force_cupy = from_dlpack(to_dlpack(final_forces)) + data.update_pair_forces_gpu(force_cupy) + except ImportError: + data.update_pair_forces_gpu(final_forces.cpu().numpy()) + else: + final_forces_np = final_forces.numpy() + data.update_pair_forces_gpu(final_forces_np) + + def _prepare_batch(self, data, natoms, nghosts, species) -> Dict[str, object]: + positions = torch.zeros((natoms + nghosts, 3), dtype=self.dtype, device=self.device) + node_attrs = species + 1 + batch_tensor = torch.zeros(natoms, dtype=torch.int64, device=self.device) + natoms_tensor = torch.tensor([natoms, natoms+nghosts], dtype=torch.int64, device=self.device) + + if hasattr(data.rij, 'get'): rij_data = data.rij.get() + else: rij_data = data.rij + + if hasattr(data.pair_i, 'get'): pair_i = data.pair_i.get() + else: pair_i = data.pair_i + + if hasattr(data.pair_j, 'get'): pair_j = data.pair_j.get() + else: pair_j = data.pair_j + + rij_tensor = torch.as_tensor(rij_data, dtype=self.dtype, device=self.device) + target_tensor = torch.as_tensor(pair_i, dtype=torch.int64, device=self.device) + source_tensor = torch.as_tensor(pair_j, dtype=torch.int64, device=self.device) + + edge_index = torch.stack([source_tensor, target_tensor], dim=0) + + return { + "positions": positions, + "vectors": rij_tensor, + "node_attrs": node_attrs, + "edge_index": edge_index, + "batch": batch_tensor, + "natoms": natoms_tensor, + "lammps_ptr": data, + } + + def compute_descriptors(self, data: Dict[str, Tensor]) -> None: + pass + + def compute_gradients(self, data: Dict[str, Tensor]) -> None: + pass \ No newline at end of file diff --git a/alphanet/models/__init__.py b/alphanet/models/__init__.py index 83c1adf..72a417a 100644 --- a/alphanet/models/__init__.py +++ b/alphanet/models/__init__.py @@ -1,2 +1,2 @@ -from .alphanet import AlphaNet +#from .alphanet import AlphaNet #from .alpha_flax import AlphaNet_flax \ No newline at end of file diff --git a/alphanet/models/alpha_haiku.py b/alphanet/models/alpha_haiku.py index ec60ea8..fca7c87 100644 --- a/alphanet/models/alpha_haiku.py +++ b/alphanet/models/alpha_haiku.py @@ -12,6 +12,8 @@ import math from functools import partial +from alphanet.models.zbl_jax import zbl_interaction, get_default_zbl_params + class Config: def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -496,58 +498,38 @@ def __call__(self, data): s = self.last_layer(s) + self.last_layer_quantum(quantum_features) / self.config.main_chi1 a_values = a[z] b_values = b[z] - V_graph = 0 - if self.zbl: - r_e = dist - Z_j = z[j] - Z_i = z[i] - - w = self.fzbl_w # (M,) - b = self.fzbl_b # (M,) - gamma = self.fzbl_gamma - alpha = self.fzbl_alpha - E2 = self.fzbl_E2 - A0 = self.fzbl_A0 - - # compute screening length a per edge: a = gamma * 0.8854 * a0 / (Z1^alpha + Z2^alpha) - denom = jnp.power(Z_j, alpha) + jnp.power(Z_i, alpha) # (E,) - denom = jnp.clip(denom, a_min=1e-12) - a_vals = gamma * 0.8854 * A0 / denom # (E,) - x = r_e / a_vals # (E,) - - # compute phi(x) = sum_i w_i * exp(-b_i * x) (vectorized) - # exp(- x[:,None] * b[None,:]) -> (E, M) - exp_terms = jnp.exp(- x[:, jnp.newaxis] * b[jnp.newaxis, :]) # (E, M) - phi_vals = exp_terms @ w # (E,) - - # pair potential per edge: V_e = Z1*Z2 * E2 * phi / r - V_edge = (Z_j * Z_i * E2) * (phi_vals / r_e) # (E,) - r_cut = 1.0 # you can make this self.fzbl_rcut buffer if you want configurable value - - # compute taper coefficient: cosine cutoff (smooth) - # for r in [0, r_cut]: c = 0.5*(cos(pi * r / r_cut) + 1) - # for r >= r_cut: c = 0 - # for safety, clamp r/r_cut in [0, 1] - xrc = jnp.clip(r_e / r_cut, 0.0, 1.0) # (E,) - # cosine taper - c = 0.5 * (jnp.cos(jnp.pi * xrc) + 1.0) # (E,) - # enforce zero beyond r_cut explicitly (cos already gives 0 at x=1 but clamp keeps numeric safe) - c = jnp.where(r_e >= r_cut, jnp.zeros_like(c), c) - - # apply taper to edge potential - V_edge = V_edge * c + ml_energy = a_values[z] * s.squeeze() + b_values[z] + if self.config.zbl: + # 获取参数 (这里假设使用默认值,如果 config 有则从 config 读) + zbl_params = get_default_zbl_params(self.dtype) + + # 也可以选择将它们注册为不可训练的 hk.parameter 或者常量 + # ... + + V_edge = zbl_interaction( + dist, z[i], z[j], + zbl_params['w'], zbl_params['b'], + zbl_params['gamma'], zbl_params['alpha'], + zbl_params['E2'], zbl_params['A0'] + ) + + # 【关键】原子能量分摊 + # i 是 target 索引 + num_atoms = z.shape[0] + zbl_per_atom = jax.ops.segment_sum(V_edge, i, num_segments=num_atoms) * 0.5 - # aggregate edge energies to graph-level using jax.ops.segment_sum - # Note: JAX uses segment_sum instead of scatter_add - graph_idx = batch[i] # map receiver node -> graph index (E,) - V_graph = jax.ops.segment_sum(V_edge, graph_idx, num_segments=1) / 2.0 + # 加到总原子能量 + total_atom_energy = ml_energy + zbl_per_atom + else: + total_atom_energy = ml_energy if s.ndim == 2: s = a_values[:, None] * s + b_values[:, None] else: s = a_values * s + b_values s = s[:, None] - s = jnp.sum(s)+V_graph#jax.ops.segment_sum(s, batch, num_segments=1)+ Vgraph - return jnp.squeeze(s) + s_total = jax.ops.segment_sum(total_atom_energy, batch, num_segments=1) # 假设 batch_size=1 用于推理 + + return s_total.squeeze() diff --git a/alphanet/models/alphanet.py b/alphanet/models/alphanet.py index 52dec7b..64a6f22 100644 --- a/alphanet/models/alphanet.py +++ b/alphanet/models/alphanet.py @@ -1,41 +1,72 @@ - import math from math import pi -from typing import Optional, Tuple, List, NamedTuple -from typing import Literal +from typing import Optional, Tuple, List + import torch -from torch import nn -from torch import Tensor +from torch import nn, Tensor from torch.nn import Embedding from torch_geometric.nn.conv import MessagePassing -from torch_scatter import scatter, scatter_add + from alphanet.models.graph import GraphData +import numpy as np + + +def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + """ + Drop-in replacement for torch_scatter.scatter using native PyTorch functions. + """ + if out is not None: + dim_size = out.size(dim) + else: + if dim_size is None: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + + out_size = list(src.size()) + out_size[dim] = dim_size + + if index.dim() != src.dim(): + curr_dims = index.dim() + target_dims = src.dim() + for _ in range(target_dims - curr_dims): + index = index.unsqueeze(-1) + index = index.expand_as(src) + + reduce = reduce.lower() + + if reduce in ['sum', 'add']: + if out is None: + out = torch.zeros(out_size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + + if reduce == 'mean': + mode = 'mean' + init_val = 0.0 + elif reduce in ['min', 'amin']: + mode = 'amin' + init_val = float('inf') + elif reduce in ['max', 'amax']: + mode = 'amax' + init_val = float('-inf') + else: + raise ValueError(f"Unknown reduce mode: {reduce}") + + if out is None: + out = torch.full(out_size, init_val, dtype=src.dtype, device=src.device) + + out.scatter_reduce_(dim, index, src, reduce=mode, include_self=False) + return out class rbf_emb(nn.Module): r_max: float prefactor: float - def __init__(self, num_basis=8, r_max = 5.0, trainable=True): - r"""Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123 - - - Parameters - ---------- - r_max : float - Cutoff radius - - num_basis : int - Number of Bessel Basis functions - - trainable : bool - Train the :math:`n \pi` part or not. - """ + def __init__(self, num_basis=8, r_max=5.0, trainable=True): super(rbf_emb, self).__init__() - self.trainable = trainable self.num_basis = num_basis - self.r_max = r_max self.prefactor = 2.0 / self.r_max @@ -48,61 +79,12 @@ def __init__(self, num_basis=8, r_max = 5.0, trainable=True): self.register_buffer("bessel_weights", bessel_weights) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Evaluate Bessel Basis for input x. - - Parameters - ---------- - x : torch.Tensor - Input - """ numerator = torch.sin(self.bessel_weights * x.unsqueeze(-1) / self.r_max) - return self.prefactor * (numerator / x.unsqueeze(-1)) - -class _rbf_emb(nn.Module): - ''' - modified: delete cutoff with r - ''' - - def __init__(self, num_rbf, rbound_upper, rbf_trainable=False): - super().__init__() - self.rbound_upper = rbound_upper - self.rbound_lower = 0 - self.num_rbf = num_rbf - self.rbf_trainable = rbf_trainable - self.pi = pi - means, betas = self._initial_params() - - self.register_buffer("means", means) - self.register_buffer("betas", betas) - - def _initial_params(self): - start_value = torch.exp(torch.scalar_tensor(-self.rbound_upper)) - end_value = torch.exp(torch.scalar_tensor(-self.rbound_lower)) - means = torch.linspace(start_value, end_value, self.num_rbf) - betas = torch.tensor([(2 / self.num_rbf * (end_value - start_value)) ** -2] * - self.num_rbf) - return means, betas - - def reset_parameters(self): - means, betas = self._initial_params() - self.means.data.copy_(means) - self.betas.data.copy_(betas) - - def forward(self, dist): - dist = dist.unsqueeze(-1) - rbounds = 0.5 * \ - (torch.cos(dist * self.pi / self.rbound_upper) + 1.0) - rbounds = rbounds * (dist < self.rbound_upper).float() - return rbounds * torch.exp(-self.betas * torch.square((torch.exp(-dist) - self.means))) class NeighborEmb(MessagePassing): - propagate_type = { - 'x': Tensor, - 'norm': Tensor - } + propagate_type = {'x': Tensor, 'norm': Tensor} def __init__(self, hid_dim: int): super(NeighborEmb, self).__init__(aggr='add') @@ -110,13 +92,7 @@ def __init__(self, hid_dim: int): self.hid_dim = hid_dim self.ln_emb = nn.LayerNorm(hid_dim, elementwise_affine=False) - def forward( - self, - z: Tensor, - s: Tensor, - edge_index: Tensor, - embs: Tensor - ) -> Tensor: + def forward(self, z: Tensor, s: Tensor, edge_index: Tensor, embs: Tensor) -> Tensor: s_neighbors = self.ln_emb(self.embedding(z)) s_neighbors = self.propagate(edge_index, x=s_neighbors, norm=embs) s = s + s_neighbors @@ -127,10 +103,7 @@ def message(self, x_j: Tensor, norm: Tensor) -> Tensor: class S_vector(MessagePassing): - propagate_type = { - 'x': Tensor, - 'norm': Tensor - } + propagate_type = {'x': Tensor, 'norm': Tensor} def __init__(self, hid_dim: int): super(S_vector, self).__init__(aggr='add') @@ -140,13 +113,7 @@ def __init__(self, hid_dim: int): nn.LayerNorm(hid_dim, elementwise_affine=False), nn.SiLU()) - def forward( - self, - s: Tensor, - v: Tensor, - edge_index: Tensor, - emb: Tensor - ) -> Tensor: + def forward(self, s: Tensor, v: Tensor, edge_index: Tensor, emb: Tensor) -> Tensor: s = self.lin1(s) emb = emb.unsqueeze(1) * v v = self.propagate(edge_index, x=s, norm=emb) @@ -157,13 +124,10 @@ def message(self, x_j: Tensor, norm: Tensor) -> Tensor: a = norm.view(-1, 3, self.hid_dim) * x_j return a.view(-1, 3 * self.hid_dim) -class EquiMessagePassing(MessagePassing): +class EquiMessagePassing(MessagePassing): propagate_type = { - 'xh': Tensor, - 'vec': Tensor, - 'rbfh_ij': Tensor, - 'r_ij': Tensor + 'xh': Tensor, 'vec': Tensor, 'rbfh_ij': Tensor, 'r_ij': Tensor } def __init__( @@ -177,7 +141,7 @@ def __init__( has_dropout_flag: bool = False, has_norm_before_flag=True, has_norm_after_flag=False, - complex_type = torch.complex64, + complex_type=torch.complex64, reduce_mode='sum', device=torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu") ): @@ -193,9 +157,12 @@ def __init__( self.hidden_channels_chi = hidden_channels_chi self.scale = nn.Linear(self.hidden_channels, self.hidden_channels_chi * 2) self.num_radial = num_radial + self.dir_proj = nn.Sequential( - nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3), nn.SiLU(inplace=True), - nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3), ) + nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3), + nn.SiLU(inplace=True), + nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3) + ) self.x_proj = nn.Sequential( nn.Linear(hidden_channels, hidden_channels), @@ -217,26 +184,26 @@ def __init__( self.dx_layer_norm = nn.LayerNorm(self.chi1) if self.has_norm_before_flag: self.dx_layer_norm = nn.LayerNorm(self.chi1 + self.hidden_channels) + self.dropout = nn.Dropout(p=0.5) self.diachi1 = torch.nn.Parameter(torch.randn((self.chi1), device=self.device)) self.scale2 = nn.Sequential( - nn.Linear(self.chi1, hidden_channels//2), + nn.Linear(self.chi1, hidden_channels // 2), ) self.kernel_real = torch.nn.Parameter(torch.randn((self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2))) self.kernel_imag = torch.nn.Parameter(torch.randn((self.head + 1, (self.hidden_channels_chi) // self.head, self.chi2))) - self.fc_mps = nn.Linear(self.chi1, self.chi1)#.to(torch.cfloat) - self.fc_dx = nn.Linear(self.chi1, hidden_channels)#.to(torch.cfloat) - self.dia = nn.Linear(self.chi1, self.chi1)#.to(torch.cfloat) + self.fc_mps = nn.Linear(self.chi1, self.chi1) + self.fc_dx = nn.Linear(self.chi1, hidden_channels) + self.dia = nn.Linear(self.chi1, self.chi1) self.unitary = torch.nn.Parameter(torch.randn((self.chi1, self.chi1), device=self.device)) self.activation = nn.SiLU() self.inv_sqrt_3 = 1 / math.sqrt(3.0) self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.x_layernorm = nn.LayerNorm(hidden_channels) - + self.reset_parameters() def reset_parameters(self): @@ -247,7 +214,6 @@ def reset_parameters(self): nn.init.xavier_uniform_(self.rbf_proj.weight) self.rbf_proj.bias.data.fill_(0) self.x_layernorm.reset_parameters() - nn.init.xavier_uniform_(self.dir_proj[0].weight) self.dir_proj[0].bias.data.fill_(0) @@ -265,16 +231,16 @@ def forward( rope: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Tensor]: if rope is not None: - real, imag = torch.split(x, [self.hidden_channels//2, self.hidden_channels//2], dim=-1) + real, imag = torch.split(x, [self.hidden_channels // 2, self.hidden_channels // 2], dim=-1) dy_pre = torch.complex(real=real, imag=imag) - dy_pre = dy_pre* rope + dy_pre = dy_pre * rope x = torch.cat([dy_pre.real, dy_pre.imag], dim=-1) + xh = self.x_proj(self.x_layernorm(x)) - rbfh = self.rbf_proj(edge_rbf) weight = self.dir_proj(weight) rbfh = rbfh * weight - # propagate_type: (xh: Tensor, vec: Tensor, rbfh_ij: Tensor, r_ij: Tensor) + dx, dvec = self.propagate( edge_index, xh=xh, @@ -282,8 +248,8 @@ def forward( rbfh_ij=rbfh, r_ij=edge_vector, size=None, - # rotation = unitary, ) + if self.has_norm_before_flag: dx = self.dx_layer_norm(dx) @@ -293,7 +259,6 @@ def forward( dx = self.dx_layer_norm(dx) dx = self.scale2(dx) - dx = torch.complex(torch.cos(dx), torch.sin(dx)) return dx, dy, dvec @@ -309,21 +274,23 @@ def message(self, xh_j, vec_j, rbfh_ij, r_ij): real = self.dropout(real) imagine = self.dropout(imagine) - # complex invariant quantum state phi = torch.complex(real, imagine) q = phi - a = torch.ones(q.shape[0], 1, (self.hidden_channels_chi) // self.head, device=self.device, dtype= self.complex_type) + a = torch.ones(q.shape[0], 1, (self.hidden_channels_chi) // self.head, device=q.device, dtype=self.complex_type) kernel = (torch.complex(self.kernel_real, self.kernel_imag) / math.sqrt((self.hidden_channels) // self.head)).expand(q.shape[0], -1, -1, -1) + equation = 'ijl, ijlk->ik' - conv = torch.einsum(equation, torch.cat([a, q], dim=1), kernel.to( self.complex_type)) + conv = torch.einsum(equation, torch.cat([a, q], dim=1), kernel.to(self.complex_type)) a = 1.0 * self.activation(self.diagonal(rbfh_ij)) - b = a.unsqueeze(-1) * self.diachi1.unsqueeze(0).unsqueeze(0) + torch.ones(kernel.shape[0], self.chi2, self.chi1, device=self.device) + b = a.unsqueeze(-1) * self.diachi1.unsqueeze(0).unsqueeze(0) + torch.ones(kernel.shape[0], self.chi2, self.chi1, device=rbfh_ij.device) dia = self.dia(b) + equation = 'ik,ikl->il' kernel = torch.einsum(equation, conv, dia.to(self.complex_type)) - kernel_real,kernel_imag = kernel.real,kernel.imag - kernel_real,kernel_imag = self.fc_mps(kernel_real),self.fc_mps(kernel_imag) + kernel_real, kernel_imag = kernel.real, kernel.imag + kernel_real, kernel_imag = self.fc_mps(kernel_real), self.fc_mps(kernel_imag) kernel = torch.angle(torch.complex(kernel_real, kernel_imag)) + agg = torch.cat([kernel, x], dim=-1) vec = vec_j * xh2.unsqueeze(1) + xh3.unsqueeze(1) * r_ij.unsqueeze(2) vec = vec * self.inv_sqrt_h @@ -339,7 +306,7 @@ def aggregate( ) -> Tuple[torch.Tensor, torch.Tensor]: x, vec = features x = scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.reduce_mode) - vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size) + vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size, reduce='sum') return x, vec def update( @@ -361,12 +328,10 @@ def __init__(self, hidden_channels): nn.Linear(hidden_channels * 2, hidden_channels), nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3) - ) self.inv_sqrt_2 = 1 / math.sqrt(2.0) self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.reset_parameters() def reset_parameters(self): @@ -378,50 +343,29 @@ def reset_parameters(self): def forward(self, x, vec): vec = self.vec_proj(vec) - vec1, vec2 = torch.split( - vec, self.hidden_channels, dim=-1 - ) + vec1, vec2 = torch.split(vec, self.hidden_channels, dim=-1) - scalar = torch.norm(vec1, dim=-2, p=1) + scalar = torch.sum(vec1**2, dim=-2) vec_dot = (vec1 * vec2).sum(dim=1) vec_dot = vec_dot * self.inv_sqrt_h - x_vec_h = self.xvec_proj( - torch.cat( - [x, scalar], dim=-1 - ) - ) - xvec1, xvec2, xvec3 = torch.split( - x_vec_h, self.hidden_channels, dim=-1 - ) + x_vec_h = self.xvec_proj(torch.cat([x, scalar], dim=-1)) + xvec1, xvec2, xvec3 = torch.split(x_vec_h, self.hidden_channels, dim=-1) dx = xvec1 + xvec2 + vec_dot dx = dx * self.inv_sqrt_2 dvec = xvec3.unsqueeze(1) * vec2 - return dx, dvec -class aggregate_pos(MessagePassing): - - def __init__(self, aggr='mean'): - super(aggregate_pos, self).__init__(aggr=aggr) - - def forward(self, vector, edge_index): - v = self.propagate(edge_index, x=vector) - - return v - - class AlphaNet(nn.Module): - def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")): super(AlphaNet, self).__init__() self.device = device self.complex_type = torch.complex64 if config.dtype == "32" else torch.complex128 - self.eps = config.eps + self.eps = 1e-9 self.num_layers = config.num_layers self.hidden_channels = config.hidden_channels self.a = nn.Parameter(torch.ones(108) * config.a) @@ -434,6 +378,7 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.num_targets = config.output_dim if config.output_dim != 0 else 1 self.compute_forces = config.compute_forces self.compute_stress = config.compute_stress + self.z_emb_ln = nn.LayerNorm(config.hidden_channels, elementwise_affine=False) self.z_emb = Embedding(95, config.hidden_channels) self.kernel1 = torch.nn.Parameter(torch.randn((config.hidden_channels, self.chi1 * 2), device=self.device)) @@ -455,11 +400,10 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.kernels_real = [] self.kernels_imag = [] self.zbl = config.zbl + if self.zbl: - M = 8 self.register_buffer('fzbl_w', torch.tensor([0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715], dtype=torch.get_default_dtype())) self.register_buffer('fzbl_b', torch.tensor([3.20,1.10,0.102,0.958,1.28,1.14,1.69,5], dtype=torch.get_default_dtype())) - # normalize weights just in case with torch.no_grad(): w = getattr(self, 'fzbl_w') w = w.clamp(min=0.0) @@ -468,8 +412,6 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl self.register_buffer('fzbl_gamma', torch.tensor(1.001, dtype=torch.get_default_dtype())) self.register_buffer('fzbl_alpha', torch.tensor(0.6032, dtype=torch.get_default_dtype())) - - # physics constants self.register_buffer('fzbl_E2', torch.tensor(14.399645478425, dtype=torch.get_default_dtype())) # eV·Å self.register_buffer('fzbl_A0', torch.tensor(0.529177210903, dtype=torch.get_default_dtype())) # Å @@ -485,7 +427,7 @@ def __init__(self, config, device=torch.device('cuda') if torch.cuda.is_availabl has_norm_before_flag=config.has_norm_before_flag, has_norm_after_flag=config.has_norm_after_flag, hidden_channels_chi=config.hidden_channels_chi, - complex_type = self.complex_type, + complex_type=self.complex_type, device=device, reduce_mode=config.reduce_mode ) @@ -520,15 +462,23 @@ def reset_parameters(self): if hasattr(layer, 'reset_parameters'): layer.reset_parameters() - def forward(self, data: GraphData, prefix: str): - + def forward( + self, + data: GraphData, + prefix: str, + compute_forces: Optional[bool] = None, + compute_stress: Optional[bool] = None, + return_atom_energy: bool = False, + ): + compute_forces = self.compute_forces if compute_forces is None else compute_forces + compute_stress = self.compute_stress if compute_stress is None else compute_stress pos = data.pos batch = data.batch z = data.z.long() edge_index = data.edge_index - dist = data.edge_attr vecs = data.edge_vec + dist = torch.linalg.norm(vecs, dim=1) z_emb = self.z_emb_ln(self.z_emb(z)) radial_emb = self.radial_emb(dist) radial_hidden = self.radial_lin(radial_emb) @@ -542,18 +492,19 @@ def forward(self, data: GraphData, prefix: str): i = edge_index[1] edge_diff = vecs edge_diff = edge_diff / (dist.unsqueeze(1) + self.eps) - mean = scatter(pos[edge_index[0]], edge_index[1], reduce='mean', dim=0) - edge_cross = torch.cross(pos[i]-mean[i], pos[j]-mean[i]) + edge_vec_mean = scatter(vecs, i, reduce='mean', dim=0) + edge_cross = torch.cross(vecs, edge_vec_mean[i]) edge_vertical = torch.cross(edge_diff, edge_cross) edge_frame = torch.cat((edge_diff.unsqueeze(-1), edge_cross.unsqueeze(-1), edge_vertical.unsqueeze(-1)), dim=-1) S_i_j = self.S_vector(s, edge_diff.unsqueeze(-1), edge_index, radial_hidden) + scalrization1 = torch.sum(S_i_j[i].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) scalrization2 = torch.sum(S_i_j[j].unsqueeze(2) * edge_frame.unsqueeze(-1), dim=1) - scalrization1[:, 1, :] = torch.abs(scalrization1[:, 1, :].clone()) - scalrization2[:, 1, :] = torch.abs(scalrization2[:, 1, :].clone()) - + scalrization1[:, 1, :] = torch.square(scalrization1[:, 1, :].clone()) + scalrization2[:, 1, :] = torch.square(scalrization2[:, 1, :].clone()) + scalar3 = (self.lin(torch.permute(scalrization1, (0, 2, 1))) + torch.permute(scalrization1, (0, 2, 1))[:, :, 0].unsqueeze(2)).squeeze(-1) / math.sqrt(self.hidden_channels) scalar4 = (self.lin(torch.permute(scalrization2, (0, 2, 1))) + @@ -582,87 +533,50 @@ def forward(self, data: GraphData, prefix: str): equation = 'ikl,bi,bl->bk' kerneli = torch.complex(kernel_real, kernel_imag) quantum = torch.einsum(equation, kerneli, s.to(self.complex_type), quantum) - quantum = quantum / quantum.abs().to(self.complex_type) + quantum = quantum / (self.eps + quantum.abs().to(self.complex_type)) ds, dvec = fte(s, vec) s = s + ds vec = vec + dvec s = self.last_layer(s) + self.last_layer_quantum(torch.cat([quantum.real, quantum.imag], dim=-1)) / self.chi1 - V_graph = 0 - if self.zbl: - r_e = dist - Z_j = z[j] - Z_i = z[i] - - # load constants (buffers) and cast to pos dtype/device - w = self.fzbl_w.to(device=s.device, dtype=s.dtype) # (M,) - b = self.fzbl_b.to(device=s.device, dtype=s.dtype) # (M,) - gamma = self.fzbl_gamma.to(device=s.device, dtype=s.dtype) - alpha = self.fzbl_alpha.to(device=s.device, dtype=s.dtype) - E2 = self.fzbl_E2.to(device=s.device, dtype=s.dtype) - A0 = self.fzbl_A0.to(device=s.device, dtype=s.dtype) - - # compute screening length a per edge: a = gamma * 0.8854 * a0 / (Z1^alpha + Z2^alpha) - denom = torch.pow(Z_j, alpha) + torch.pow(Z_i, alpha) # (E,) - denom = torch.clamp(denom, min=1e-12) - a_vals = gamma * 0.8854 * A0 / denom # (E,) - x = r_e / a_vals # (E,) - - # compute phi(x) = sum_i w_i * exp(-b_i * x) (vectorized) - # exp(- x[:,None] * b[None,:]) -> (E, M) - exp_terms = torch.exp(- x.unsqueeze(1) * b.unsqueeze(0)) # (E, M) - phi_vals = exp_terms.matmul(w) # (E,) - - # pair potential per edge: V_e = Z1*Z2 * E2 * phi / r - V_edge = (Z_j * Z_i * E2) * (phi_vals / r_e) # (E,) - r_cut = 1.0 # you can make this self.fzbl_rcut buffer if you want configurable value - - # compute taper coefficient: cosine cutoff (smooth) - # for r in [0, r_cut]: c = 0.5*(cos(pi * r / r_cut) + 1) - # for r >= r_cut: c = 0 - # for safety, clamp r/r_cut in [0, 1] - xrc = (r_e / r_cut).clamp(min=0.0, max=1.0) # (E,) - # cosine taper - c = 0.5 * (torch.cos(torch.pi * xrc) + 1.0) # (E,) - # enforce zero beyond r_cut explicitly (cos already gives 0 at x=1 but clamp keeps numeric safe) - c = torch.where(r_e >= r_cut, torch.zeros_like(c), c) - - # apply taper to edge potential - V_edge = V_edge * c - # aggregate edge energies to graph-level (use receiver node's batch index) - # use torch_scatter.scatter_add (or your existing scatter) to sum per-graph - - graph_idx = batch[i] # map receiver node -> graph index (E,) - V_graph = scatter_add(V_edge, graph_idx, dim=0) / 2.0 + if s.dim() == 2: s = (self.a[z].unsqueeze(1) * s + self.b[z].unsqueeze(1)) elif s.dim() == 1: s = (self.a[z] * s + self.b[z]).unsqueeze(1) else: raise ValueError(f"Unexpected shape of s: {s.shape}") - #print(s.shape, V_graph.shape, batch.shape) - s = scatter(s, batch, dim=0, reduce=self.readout).squeeze()#+ V_graph - #print(s.shape) + + atom_energy = s.squeeze(-1) if s.dim() == 2 and s.size(-1) == 1 else s + total_energy = scatter(atom_energy, batch, dim=0, reduce=self.readout).squeeze() + if self.use_sigmoid: - s = torch.sigmoid((s - 0.5) * 5) - #return s, None, None - if self.compute_forces and self.compute_stress: + if return_atom_energy: + raise ValueError("Per-atom energy is not defined when sigmoid readout is enabled") + total_energy = torch.sigmoid((total_energy - 0.5) * 5) + if compute_forces and compute_stress: if data.displacement is not None: - stress, forces = self.cal_stress_and_force(s, pos, data.displacement, data.cell, prefix) + stress, forces = self.cal_stress_and_force(total_energy, pos, data.displacement, data.cell, prefix) stress = stress.view(-1, 3) else: stress = None forces = None - return s, forces, stress - elif self.compute_forces: - forces = self.cal_forces(s, pos, prefix) - return s, forces, None - return s, None, None + if return_atom_energy: + return total_energy, forces, stress, atom_energy + return total_energy, forces, stress + elif compute_forces: + forces = self.cal_forces(total_energy, pos, prefix) + if return_atom_energy: + return total_energy, forces, None, atom_energy + return total_energy, forces, None + + if return_atom_energy: + return total_energy, None, None, atom_energy + return total_energy, None, None def cal_forces(self, energy, positions, prefix: str = 'infer'): - graph = (prefix == "train") grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones_like(energy)]) forces = torch.autograd.grad( @@ -676,9 +590,9 @@ def cal_forces(self, energy, positions, prefix: str = 'infer'): assert forces is not None, "Gradient should not be None" return -forces - def cal_stress_and_force(self, energy: Tensor,positions: Tensor, displacement: Optional[Tensor], cell: Tensor, prefix: str) -> Tuple[Tensor, Tensor]: + def cal_stress_and_force(self, energy: Tensor, positions: Tensor, displacement: Optional[Tensor], cell: Tensor, prefix: str) -> Tuple[Tensor, Tensor]: if displacement is None: - raise ValueError("displacement cannot be None for stress calculation") + raise ValueError("displacement cannot be None for stress calculation") graph = (prefix == "train") grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones_like(energy)]) output = torch.autograd.grad( @@ -694,10 +608,7 @@ def cal_stress_and_force(self, energy: Tensor,positions: Tensor, displacement: O volume = torch.abs(torch.linalg.det(cell)) volume_expanded = volume.reshape(-1, 1, 1) stress = virial / volume_expanded - force =output[1] + force = output[1] assert force is not None, "Forces tensor should not be None" return stress, -force - - - diff --git a/alphanet/models/graph.py b/alphanet/models/graph.py index 8e18476..dd43ed5 100644 --- a/alphanet/models/graph.py +++ b/alphanet/models/graph.py @@ -1,8 +1,11 @@ +from dataclasses import dataclass +from typing import List, NamedTuple, Optional, Tuple + +import numpy as np import torch +from ase import Atoms +from matscipy.neighbours import neighbour_list from torch import Tensor -from typing import Optional, Tuple, NamedTuple, List -from torch_scatter import segment_coo, segment_csr - class GraphData(NamedTuple): @@ -10,7 +13,7 @@ class GraphData(NamedTuple): batch: Tensor z: Tensor natoms: Tensor - edge_index: Tensor#Tuple[Tensor, Tensor] + edge_index: Tensor edge_attr: Tensor edge_vec: Tensor cell: Tensor = None @@ -18,315 +21,223 @@ class GraphData(NamedTuple): displacement: Optional[Tensor] = None pbc: Optional[Tensor] = None -def get_max_neighbors_mask( - natoms: Tensor, - index: Tensor, - atom_distance: Tensor, - max_num_neighbors_threshold: int, - precision: torch.dtype -) : - """ - Give a mask that filters out edges so that each atom has at most - `max_num_neighbors_threshold` neighbors. - Assumes that `index` is sorted. - """ - device = natoms.device - num_atoms = natoms.sum() - - # Get number of neighbors - # segment_coo assumes sorted index - ones = index.new_ones(1).expand_as(index) - num_neighbors = segment_coo(ones, index, dim_size=num_atoms) - max_num_neighbors = num_neighbors.max() - num_neighbors_thresholded = num_neighbors.clamp( - max=max_num_neighbors_threshold - ) - # Get number of (thresholded) neighbors per image - image_indptr = torch.zeros( - natoms.shape[0] + 1, device=device, dtype=torch.long - ) - image_indptr[1:] = torch.cumsum(natoms, dim=0) - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) - - # If max_num_neighbors is below the threshold, return early - if ( - max_num_neighbors <= max_num_neighbors_threshold - or max_num_neighbors_threshold <= 0 - ): - mask_num_neighbors = torch.tensor( - [True], dtype=torch.bool, device=device - ).expand_as(index) - return mask_num_neighbors, num_neighbors_image - - # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. - # Fill with infinity so we can easily remove unused distances later. - #distance_sort = torch.full( - # [num_atoms * max_num_neighbors], np.inf, device=device - #) - distance_sort = torch.ones( - num_atoms * max_num_neighbors, - device=device - ) * float('inf') - distance_sort = distance_sort.to(precision) - # Create an index map to map distances from atom_distance to distance_sort - # index_sort_map assumes index to be sorted - index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors - index_neighbor_offset_expand = torch.repeat_interleave( - index_neighbor_offset, num_neighbors - ) - index_sort_map = ( - index * max_num_neighbors - + torch.arange(len(index), device=device) - - index_neighbor_offset_expand +@dataclass +class NeighborTopology: + edge_index: Tensor + cell_offsets: Tensor + neighbors: Tensor + reference_positions: Tensor + reference_cell: Tensor + cutoff: float + skin: float + + +def _to_numpy_array(data) -> np.ndarray: + if isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + return np.asarray(data) + + +def _build_single_image_topology_matscipy( + positions: np.ndarray, + cell: np.ndarray, + numbers: np.ndarray, + radius: float, + pbc: np.ndarray, + edge_source_first: bool, +) -> Tuple[np.ndarray, np.ndarray, int]: + atoms = Atoms( + numbers=numbers, + positions=positions, + cell=cell, + pbc=pbc, ) - distance_sort.index_copy_(0, index_sort_map, atom_distance) - distance_sort = distance_sort.view(num_atoms, max_num_neighbors) - - # Sort neighboring atoms based on distance - distance_sort, index_sort = torch.sort(distance_sort, dim=1) - # Select the max_num_neighbors_threshold neighbors that are closest - distance_sort = distance_sort[:, :max_num_neighbors_threshold] - index_sort = index_sort[:, :max_num_neighbors_threshold] - - # Offset index_sort so that it indexes into index - index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( - -1, max_num_neighbors_threshold + index_i, index_j, shift = neighbour_list( + quantities="ijS", + atoms=atoms, + cutoff=radius, ) - # Remove "unused pairs" with infinite distances - mask_finite = torch.isfinite(distance_sort) - index_sort = torch.masked_select(index_sort, mask_finite) - # At this point index_sort contains the index into index of the - # closest max_num_neighbors_threshold neighbors per atom - # Create a mask to remove all pairs not in index_sort - mask_num_neighbors = torch.zeros(len(index), device=device, dtype=torch.bool) - mask_num_neighbors.index_fill_(0, index_sort, torch.tensor(True, device=device)) + index_i = np.asarray(index_i, dtype=np.int64) + index_j = np.asarray(index_j, dtype=np.int64) + shift = np.asarray(shift, dtype=np.int32) + + if index_i.size == 0: + edge_index = np.empty((2, 0), dtype=np.int64) + cell_offsets = np.empty((0, 3), dtype=np.int32) + return edge_index, cell_offsets, 0 + + if edge_source_first: + edge_index = np.stack((index_j, index_i), axis=0) + else: + edge_index = np.stack((index_i, index_j), axis=0) + + return edge_index, shift, int(index_i.shape[0]) - return mask_num_neighbors, num_neighbors_image def check_and_reshape_cell(cell: Optional[torch.Tensor]) -> torch.Tensor: - if cell is None: - return torch.eye(3, dtype=torch.float32).unsqueeze(0) - + return torch.eye(3, dtype=torch.float32).unsqueeze(0) if cell.dim() == 2 and cell.size(0) % 3 == 0 and cell.size(1) == 3: batch_size = cell.size(0) // 3 cell = cell.reshape(batch_size, 3, 3) elif cell.dim() != 3 or cell.size(1) != 3 or cell.size(2) != 3: raise ValueError(f"Invalid cell shape. Expected (batch_size, 3, 3), but got {cell.size()}") - + return cell -def radius_graph_pbc( - pos: Tensor, - natoms: Tensor, - cell: Tensor, - radius: float, - max_num_neighbors_threshold: int, - pbc: Optional[List[bool]] = None, - precision: torch.dtype = torch.float32 -): - if pbc is None: - pbc = [True, True, True] - device = pos.device - batch_size = len(natoms) - atom_pos = pos - num_atoms_per_image = natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - index_offset = ( - torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - ) - index_offset_expand = torch.repeat_interleave( - index_offset, num_atoms_per_image_sqr - ) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image - # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement - # the following (but 10x faster since it removes the for loop) - # for batch_idx in range(batch_size): - # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = ( - torch.arange(num_atom_pairs, device=device) - index_sqr_offset +def build_neighbor_topology( + pos: Tensor, + natoms: Tensor, + cell: Tensor, + cutoff: float, + skin: float = 0.0, + pbc: Optional[List[bool]] = None, + precision: torch.dtype = torch.float32, + numbers: Optional[Tensor] = None, + edge_source_first: bool = True, +) -> NeighborTopology: + cell = check_and_reshape_cell(cell) + radius = cutoff + max(skin, 0.0) + device = pos.device + pbc_array = np.asarray( + [True, True, True] if pbc is None else pbc, + dtype=bool, ) + natoms_np = _to_numpy_array(natoms).astype(np.int64) + pos_np = _to_numpy_array(pos).astype(np.float64, copy=False) + cell_np = _to_numpy_array(cell).astype(np.float64, copy=False) - # Compute the indices for the pairs of atoms (using division and mod) - # If the systems get too large this apporach could run into numerical precision issues - index1 = ( - torch.div( - atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor" - ) - ) + index_offset_expand - index2 = ( - atom_count_sqr % num_atoms_per_image_expand - ) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction. - # Smallest distance between planes separated by a1 is - # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. - # Note that the unit cell volume V = a1 * (a2 x a3) and that - # (a2 x a3) / V is also the reciprocal primitive vector - # (crystallographer's definition). - #print(data.cell.shape) - cross_a2a3 = torch.cross(cell[:, 1], cell[:, 2], dim=-1) - cell_vol = torch.sum(cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(cell[:, 2], cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) + if numbers is None: + numbers_np = np.ones(pos_np.shape[0], dtype=np.int32) else: - rep_a2 = cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(cell[:, 0], cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) + numbers_np = _to_numpy_array(numbers).astype(np.int32, copy=False) + + edge_indices = [] + cell_offsets = [] + num_neighbors_image = [] + + atom_offset = 0 + for image_index, image_natoms in enumerate(natoms_np): + image_natoms = int(image_natoms) + image_slice = slice(atom_offset, atom_offset + image_natoms) + image_edge_index, image_offsets, image_neighbors = _build_single_image_topology_matscipy( + positions=pos_np[image_slice], + cell=cell_np[image_index], + numbers=numbers_np[image_slice], + radius=radius, + pbc=pbc_array, + edge_source_first=edge_source_first, + ) + if image_neighbors > 0: + image_edge_index = image_edge_index + atom_offset + edge_indices.append(torch.from_numpy(image_edge_index)) + cell_offsets.append(torch.from_numpy(image_offsets)) + num_neighbors_image.append(image_neighbors) + atom_offset += image_natoms + + if edge_indices: + edge_index = torch.cat(edge_indices, dim=1).to( + device=device, + dtype=torch.long, + ) + cell_offsets_tensor = torch.cat(cell_offsets, dim=0).to( + device=device, + dtype=torch.int32, + ) else: - rep_a3 = cell.new_zeros(1) - - # Take the max over all images for uniformity. This is essentially padding. - # Note that this can significantly increase the number of computed distances - # if the required repetitions are very different between images - # (which they usually are). Changing this to sparse (scatter) operations - # might be worth the effort if this function becomes a bottleneck. - max_rep = [int(rep_a1.max()), int(rep_a2.max()), int(rep_a3.max())] - - # Tensor of unit cells - cells_per_dim = [ - torch.arange(-rep, rep + 1, device=device, dtype=precision) - for rep in max_rep - ] - unit_cell = torch.cartesian_prod(cells_per_dim[0],cells_per_dim[1], cells_per_dim[2]) - num_cells = len(unit_cell) - unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat( - len(index2), 1, 1 - ) - unit_cell = torch.transpose(unit_cell, 0, 1) - unit_cell_batch = unit_cell.view(1, 3, num_cells).expand( - batch_size, -1, -1 - ) - - # Compute the x, y, z positional offsets for each cell in each image - data_cell = torch.transpose(cell, 1, 2) - - pbc_offsets = torch.bmm(data_cell, unit_cell_batch) - pbc_offsets_per_atom = torch.repeat_interleave( - pbc_offsets, num_atoms_per_image_sqr, dim=0 + edge_index = torch.empty((2, 0), device=device, dtype=torch.long) + cell_offsets_tensor = torch.empty((0, 3), device=device, dtype=torch.int32) + neighbors = torch.tensor( + num_neighbors_image, + device=device, + dtype=torch.long, ) - - # Expand the positions and indices for the 9 cells - pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) - pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) - index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) - index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) - # Add the PBC offsets for the second atom - pos2 = pos2 + pbc_offsets_per_atom - - # Compute the squared distance between atoms - atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) - atom_distance_sqr = atom_distance_sqr.view(-1) - - # Remove pairs that are too far apart - mask_within_radius = torch.le(atom_distance_sqr, radius * radius) - # Remove pairs with the same atoms (distance = 0.0) - mask_not_same = torch.gt(atom_distance_sqr, 0.0001) - mask = torch.logical_and(mask_within_radius, mask_not_same) - index1 = torch.masked_select(index1, mask) - index2 = torch.masked_select(index2, mask) - unit_cell = torch.masked_select( - unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - precision = precision + return NeighborTopology( + edge_index=edge_index, + cell_offsets=cell_offsets_tensor, + neighbors=neighbors, + reference_positions=pos.detach().clone(), + reference_cell=cell.detach().clone(), + cutoff=cutoff, + skin=max(skin, 0.0), ) - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - return edge_index, unit_cell, num_neighbors_image - -def get_pbc_distances( +def _update_edge_geometry( pos: Tensor, + batch: Tensor, edge_index: Tensor, cell: Tensor, cell_offsets: Tensor, - neighbors: Tensor, - return_offsets: bool = False, - return_distance_vec: bool = False, - precision: torch.dtype = torch.float32 -): - row= edge_index[0] + precision: torch.dtype = torch.float32, + cutoff: Optional[float] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + row = edge_index[0] col = edge_index[1] + edge_batch = batch[col] + cell_per_edge = cell[edge_batch] + offsets = ( + cell_offsets.to(precision) + .view(-1, 1, 3) + .bmm(cell_per_edge.to(precision)) + .view(-1, 3) + ) + distance_vectors = pos[row] - pos[col] + offsets + distances = distance_vectors.norm(dim=-1, p=2) + valid_mask = distances > 0 + if cutoff is not None: + valid_mask = torch.logical_and(valid_mask, distances <= cutoff) + edge_index = edge_index[:, valid_mask] + cell_offsets = cell_offsets[valid_mask] + distances = distances[valid_mask] + distance_vectors = distance_vectors[valid_mask] + return edge_index, cell_offsets, distances, distance_vectors + + +def graph_from_neighbor_topology( + pos: Tensor, + z: Tensor, + natoms: Tensor, + batch: Tensor, + topology: NeighborTopology, + cell: Optional[Tensor] = None, + displacement: Optional[Tensor] = None, + cutoff: Optional[float] = None, + dtype: torch.dtype = torch.float32, +) -> GraphData: + precision = dtype + pos = pos.to(precision) + z = z.long() + cell = check_and_reshape_cell(cell) + edge_index, cell_offsets, dist, vecs = _update_edge_geometry( + pos=pos, + batch=batch, + edge_index=topology.edge_index, + cell=cell, + cell_offsets=topology.cell_offsets, + precision=precision, + cutoff=topology.cutoff if cutoff is None else cutoff, + ) + return GraphData( + pos=pos, + z=z, + natoms=natoms, + batch=batch, + edge_index=edge_index, + edge_attr=dist, + edge_vec=vecs, + cell=cell, + cell_offsets=cell_offsets, + displacement=displacement, + ) - distance_vectors = pos[row] - pos[col] - - # correct for pbc - neighbors = neighbors.to(cell.device) - cell = torch.repeat_interleave(cell, neighbors, dim=0) - offsets = cell_offsets.to(precision).view(-1, 1, 3).bmm(cell.to(precision)).view(-1, 3) - distance_vectors += offsets - - # compute distances - distances = distance_vectors.norm(dim=-1 , p=2) - - # redundancy: remove zero distances - nonzero_idx = torch.arange(len(distances), device=distances.device)[ - distances != 0 - ] - edge_index = edge_index[:, nonzero_idx] - distances = distances[nonzero_idx] - - out = { - "edge_index": edge_index, - "distances": distances, - } - - if return_distance_vec: - out["distance_vec"] = distance_vectors[nonzero_idx] - - if return_offsets: - out["offsets"] = offsets[nonzero_idx] - - return out # Borrowed from MACE -def get_symmetric_displacement( +def get_symmetric_displacement( positions: torch.Tensor, cell: Optional[torch.Tensor], num_graphs: int, @@ -339,18 +250,18 @@ def get_symmetric_displacement( dtype=positions.dtype, device=positions.device, ) - + displacement = torch.zeros( (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device, - ) - + ) + displacement.requires_grad_(True) symmetric_displacement = 0.5 * ( displacement + displacement.transpose(-1, -2) ) - + positions = positions + torch.einsum( "be,bec->bc", positions, symmetric_displacement[batch] ) @@ -359,6 +270,7 @@ def get_symmetric_displacement( cell.view(-1, 3) return positions, cell, displacement + def process_positions_and_edges( pos: Tensor, z: Tensor, @@ -372,132 +284,45 @@ def process_positions_and_edges( dtype: torch.dtype = torch.float32 ) -> GraphData: """ - Process atomic positions and compute edges with optional PBC support. - We found that non-pbc graph is not compatible with jit compile, so we don't support that for now, please create a large cell if you want to do non-pbc calculation. - Args: - data: Input data object containing positions, batch info, and other attributes - compute_forces: Boolean flag for force computation - compute_stress: Boolean flag for stress computation - use_pbc: Boolean flag for periodic boundary conditions - cutoff: Cutoff radius for neighbor search - dtype: torch dtype precision - - Returns: - Data: Data object containing processed attributes + Process atomic positions and compute edges with PBC support. + + Non-PBC mode is not supported directly; create a large vacuum cell instead. """ precision = dtype pos = pos.to(precision) z = z.long() - + if compute_stress: - pos, cell, displacement = get_symmetric_displacement( - pos, cell, num_graphs=int(torch.max(batch))+1, batch=batch + pos, cell, num_graphs=int(torch.max(batch)) + 1, batch=batch ) else: displacement = None - + cell = check_and_reshape_cell(cell) - - if use_pbc and cell is not None: - - edge_index, cell_offsets, neighbors = radius_graph_pbc( - pos, natoms, cell, cutoff, max_num_neighbors_threshold=50, precision = precision - ) - #print(edge_index) - out = get_pbc_distances( - pos, - edge_index, - cell, - cell_offsets, - neighbors, - return_distance_vec=True, - precision = precision + + if not use_pbc or cell is None: + raise ValueError( + "Non-PBC mode is not supported; please create a large vacuum cell." ) - edge_index = out["edge_index"] - dist = out["distances"] - vecs = out["distance_vec"] - - else: - raise ValueError(f"None PBC is not supporting yet, as radius graph is not compilable with jit") - - - return GraphData( + + topology = build_neighbor_topology( pos=pos, - z=z, natoms=natoms, - batch=batch, - edge_index=edge_index, - edge_attr=dist, - edge_vec=vecs, cell=cell, - cell_offsets=cell_offsets, - displacement=displacement, + cutoff=cutoff, + skin=0.0, + precision=precision, + numbers=z, ) - - -def _process_positions_and_edges( - pos: Tensor, - z: Tensor, - natoms: Tensor, - batch: Tensor, - cell: Optional[Tensor] = None, - compute_stress: bool = False, - compute_forces: bool = False, - use_pbc: bool = False, - cutoff: float = 5.0, - dtype: torch.dtype = torch.float32 -) -> GraphData: - """ - Process atomic positions and compute edges with optional PBC support. - Added 100 ghost atoms that are not connected to any nodes. - """ - precision = dtype - pos = pos.to(precision) - z = z.long() - num_graphs = int(torch.max(batch)) + 1 # 获取图的数量 - - if compute_stress: - new_pos, cell, displacement = get_symmetric_displacement( - pos, cell, num_graphs=int(torch.max(batch))+1, batch=batch - ) - else: - displacement = None - - cell = check_and_reshape_cell(cell) - - if use_pbc and cell is not None: - edge_index, cell_offsets, neighbors = radius_graph_pbc( - pos, natoms, cell, cutoff, max_num_neighbors_threshold=50, precision=precision - ) - new_pos = pos - new_z = z - new_batch = batch - out = get_pbc_distances( - new_pos, - edge_index, - cell, - cell_offsets, - neighbors, - return_distance_vec=True, - precision=precision - ) - edge_index = out["edge_index"] - dist = out["distances"] - vecs = out["distance_vec"] - else: - raise ValueError("None PBC is not supporting yet, as radius graph is not compilable with jit") - - - return GraphData( - pos=new_pos, - z=new_z, - natoms=new_natoms, - batch=new_batch, - edge_index=edge_index, - edge_attr=dist, - edge_vec=vecs, + return graph_from_neighbor_topology( + pos=pos, + z=z, + natoms=natoms, + batch=batch, + topology=topology, cell=cell, - cell_offsets=cell_offsets, displacement=displacement, + cutoff=cutoff, + dtype=precision, ) diff --git a/alphanet/models/model.py b/alphanet/models/model.py index 068e79f..5acc385 100644 --- a/alphanet/models/model.py +++ b/alphanet/models/model.py @@ -3,8 +3,10 @@ from torch import Tensor from typing import Optional from alphanet.models.alphanet import AlphaNet -from alphanet.models.graph import process_positions_and_edges +from alphanet.models.graph import GraphData, process_positions_and_edges from alphanet.config import AlphaConfig + + class AlphaNetWrapper(torch.nn.Module): def __init__( self, @@ -17,27 +19,67 @@ def __init__( self.compute_stress = config.compute_stress self.use_pbc = config.use_pbc self.precision = torch.float32 if config.dtype == "32" else torch.float64 - def forward(self, + + def _resolve_compute_flags( + self, + compute_forces: Optional[bool], + compute_stress: Optional[bool], + ): + resolved_forces = self.compute_forces if compute_forces is None else compute_forces + resolved_stress = self.compute_stress if compute_stress is None else compute_stress + return resolved_forces, resolved_stress + + def forward( + self, pos: Tensor, z: Tensor, batch: Tensor, natoms: Tensor, cell: Optional[Tensor] = None, - prefix: str = 'infer'): - + prefix: str = 'infer', + compute_forces: Optional[bool] = None, + compute_stress: Optional[bool] = None, + return_atom_energy: bool = False): + compute_forces, compute_stress = self._resolve_compute_flags( + compute_forces, + compute_stress, + ) processed_data = process_positions_and_edges( pos=pos, z=z, natoms=natoms, batch=batch, cell=cell, - compute_forces=self.compute_forces, - compute_stress=self.compute_stress, + compute_forces=compute_forces, + compute_stress=compute_stress, use_pbc=self.use_pbc, cutoff=self.cutoff, dtype=self.precision ) - - output = self.model(processed_data, prefix) - - return output \ No newline at end of file + return self.forward_graph( + processed_data, + prefix=prefix, + compute_forces=compute_forces, + compute_stress=compute_stress, + return_atom_energy=return_atom_energy, + ) + + def forward_graph( + self, + graph_data: GraphData, + prefix: str = "infer", + compute_forces: Optional[bool] = None, + compute_stress: Optional[bool] = None, + return_atom_energy: bool = False, + ): + compute_forces, compute_stress = self._resolve_compute_flags( + compute_forces, + compute_stress, + ) + return self.model( + graph_data, + prefix, + compute_forces=compute_forces, + compute_stress=compute_stress, + return_atom_energy=return_atom_energy, + ) diff --git a/alphanet/models/zbl.py b/alphanet/models/zbl.py new file mode 100644 index 0000000..4304bd2 --- /dev/null +++ b/alphanet/models/zbl.py @@ -0,0 +1,76 @@ +# 文件路径: AlphaNet-lammps/alphanet/models/zbl.py +import torch +from torch import nn +import math + +class ZBLPotential(nn.Module): + def __init__(self, config): + super().__init__() + # 默认参数 (通用拟合值) + default_w = [0.187, 0.3769, 0.189, 0.081, 0.003, 0.037, 0.0546, 0.0715] + default_b = [3.20, 1.10, 0.102, 0.958, 1.28, 1.14, 1.69, 5.0] + + # 从配置读取,支持微调 ZBL 参数 + w = getattr(config, 'zbl_w', default_w) + if w is None: w = default_w + + b = getattr(config, 'zbl_b', default_b) + if b is None: b = default_b + + gamma = getattr(config, 'zbl_gamma', 1.001) + alpha = getattr(config, 'zbl_alpha', 0.6032) + + self.register_buffer('fzbl_w', torch.tensor(w, dtype=torch.get_default_dtype())) + self.register_buffer('fzbl_b', torch.tensor(b, dtype=torch.get_default_dtype())) + + # 归一化权重 + with torch.no_grad(): + w_tensor = self.fzbl_w + w_tensor = w_tensor.clamp(min=0.0) + w_tensor = w_tensor / (w_tensor.sum() + 1e-12) + self.fzbl_w.copy_(w_tensor) + + self.register_buffer('fzbl_gamma', torch.tensor(gamma, dtype=torch.get_default_dtype())) + self.register_buffer('fzbl_alpha', torch.tensor(alpha, dtype=torch.get_default_dtype())) + + # 物理常数 + self.register_buffer('fzbl_E2', torch.tensor(14.399645478425, dtype=torch.get_default_dtype())) # eV·Å + self.register_buffer('fzbl_A0', torch.tensor(0.529177210903, dtype=torch.get_default_dtype())) # Å + + # 平滑截断参数 + self.r_cut = 1.0 + + def forward(self, dist: torch.Tensor, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor: + """ + Pair ZBL Energy V_ij + """ + r_e = dist + + # 确保计算精度一致 + dtype = r_e.dtype + w = self.fzbl_w.to(dtype=dtype) + b = self.fzbl_b.to(dtype=dtype) + gamma = self.fzbl_gamma.to(dtype=dtype) + alpha = self.fzbl_alpha.to(dtype=dtype) + E2 = self.fzbl_E2.to(dtype=dtype) + A0 = self.fzbl_A0.to(dtype=dtype) + + denom = torch.pow(z_j, alpha) + torch.pow(z_i, alpha) + denom = torch.clamp(denom, min=1e-12) + a_vals = gamma * 0.8854 * A0 / denom + x = r_e / a_vals + + # V(r) = (Z1*Z2*e^2/r) * phi(x) + # phi(x) = sum(w * exp(-b*x)) + exp_terms = torch.exp(-x.unsqueeze(1) * b.unsqueeze(0)) # [Edges, M] + phi_vals = exp_terms.matmul(w) # [Edges] + + V_edge = (z_j * z_i * E2) * (phi_vals / r_e) + + # Cutoff smoothing (0 to r_cut) + + xrc = (r_e / self.r_cut).clamp(min=0.0, max=1.0) + c = 0.5 * (torch.cos(math.pi * xrc) + 1.0) + c = torch.where(r_e >= self.r_cut, torch.zeros_like(c), c) + + return V_edge * c \ No newline at end of file diff --git a/alphanet/models/zbl_jax.py b/alphanet/models/zbl_jax.py new file mode 100644 index 0000000..041d053 --- /dev/null +++ b/alphanet/models/zbl_jax.py @@ -0,0 +1,48 @@ +import jax.numpy as jnp +import jax + +def zbl_interaction(dist, z_i, z_j, w, b, gamma, alpha, E2, A0, r_cut=1.0): + """ + 计算 ZBL Pair Potential (JAX functional implementation) + """ + r_e = dist + + # 防止除零 + denom = jnp.power(z_j, alpha) + jnp.power(z_i, alpha) + denom = jnp.clip(denom, a_min=1e-12) + + a_vals = gamma * 0.8854 * A0 / denom + x = r_e / a_vals + + # phi(x) = sum(w * exp(-b*x)) + # w: (M,), b: (M,), x: (Edges,) + # exp(- x[:, None] * b[None, :]) -> (Edges, M) + exp_terms = jnp.exp(-x[:, jnp.newaxis] * b[jnp.newaxis, :]) + phi_vals = jnp.dot(exp_terms, w) + + V_edge = (z_j * z_i * E2) * (phi_vals / r_e) + + # Cutoff smoothing + xrc = jnp.clip(r_e / r_cut, 0.0, 1.0) + c = 0.5 * (jnp.cos(jnp.pi * xrc) + 1.0) + c = jnp.where(r_e >= r_cut, jnp.zeros_like(c), c) + + return V_edge * c + +def get_default_zbl_params(dtype=jnp.float32): + fzbl_w = jnp.array([0.187, 0.3769, 0.189, 0.081, 0.003, 0.037, 0.0546, 0.0715], dtype=dtype) + # Normalize + fzbl_w = jnp.clip(fzbl_w, a_min=0.0) + fzbl_w = fzbl_w / (jnp.sum(fzbl_w) + 1e-12) + + fzbl_b = jnp.array([3.20, 1.10, 0.102, 0.958, 1.28, 1.14, 1.69, 5.0], dtype=dtype) + + params = { + "w": fzbl_w, + "b": fzbl_b, + "gamma": jnp.array(1.001, dtype=dtype), + "alpha": jnp.array(0.6032, dtype=dtype), + "E2": jnp.array(14.399645478425, dtype=dtype), + "A0": jnp.array(0.529177210903, dtype=dtype) + } + return params \ No newline at end of file diff --git a/alphanet/mul_trainer.py b/alphanet/mul_trainer.py index a3085f1..6dc6627 100644 --- a/alphanet/mul_trainer.py +++ b/alphanet/mul_trainer.py @@ -1,5 +1,5 @@ import torch -from torch.optim import Adam, AdamW +from torch.optim import Adam, AdamW, RAdam from torch_geometric.data import DataLoader from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR import pytorch_lightning as pl @@ -177,7 +177,19 @@ def test_step(self, batch, batch_idx): return loss def configure_optimizers(self): - optimizer = Adam(self.parameters(), lr=self.config.train.lr, weight_decay=self.config.train.weight_decay) + + opt_name = self.config.train.optimizer.lower() + lr = self.config.train.lr + weight_decay = self.config.train.weight_decay + + if opt_name == 'adam': + optimizer = Adam(self.parameters(), lr=lr, weight_decay=weight_decay) + elif opt_name == 'adamw': + optimizer = AdamW(self.parameters(), lr=lr, weight_decay=weight_decay) + elif opt_name == 'radam': + optimizer = RAdam(self.parameters(), lr=lr, weight_decay=weight_decay) + else: + raise ValueError(f"Unknown optimizer: {opt_name}") if self.config.train.scheduler == 'steplr': scheduler = StepLR(optimizer, step_size=self.config.train.lr_decay_step_size, gamma=self.config.train.lr_decay_factor) diff --git a/alphanet/train.py b/alphanet/train.py index ea32865..06fd768 100644 --- a/alphanet/train.py +++ b/alphanet/train.py @@ -5,45 +5,89 @@ from alphanet.models.model import AlphaNetWrapper from alphanet.mul_trainer import Trainer -def run_training(config1,config2): - - train_dataset, valid_dataset, test_dataset = get_pic_datasets(root='dataset/', name=config1.dataset_name,config = config1) +def run_training(config1, runtime_config): + + train_dataset, valid_dataset, test_dataset = get_pic_datasets( + root='dataset/', + name=config1.dataset_name, + config=config1 + ) + force_std = torch.std(train_dataset.data.force).item() - - energy_peratom = torch.sum(train_dataset.data.y).item()/torch.sum(train_dataset.data.natoms).item() + if hasattr(train_dataset.data, 'y') and train_dataset.data.y is not None: + energy_peratom = torch.sum(train_dataset.data.y).item() / torch.sum(train_dataset.data.natoms).item() + else: + energy_peratom = 0.0 + config1.a = force_std config1.b = energy_peratom - #print(config1.a, config1.b) + model = AlphaNetWrapper(config1) - #print(model.model.a, model.model.b) + if config1.dtype == "64": - model = model.double() - #strategy = DDPStrategy(num_nodes=config["hardware"]["num_nodes"]) if config["hardware"]["num_nodes"] > 1 else "auto" + model = model.double() + + + if runtime_config.get("finetune_path"): + ft_path = runtime_config["finetune_path"] + print(f"🔨 Finetuning mode: Loading weights from {ft_path}...") + + + try: + ckpt = torch.load(ft_path, map_location='cpu') + except FileNotFoundError: + raise FileNotFoundError(f"Finetune checkpoint not found at: {ft_path}") + + + if isinstance(ckpt, dict) and 'state_dict' in ckpt: + state_dict = ckpt['state_dict'] + else: + state_dict = ckpt + new_state_dict = state_dict + + missing, unexpected = model.load_state_dict(new_state_dict, strict=False) + + if len(missing) > 0: + print(f" Warning: Missing keys ({len(missing)}): {missing[:3]} ...") + if len(unexpected) > 0: + print(f" Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]} ...") + + print(" ✅ Weights loaded successfully. Optimizer states reset for finetuning.") + else: + print("🆕 Training from scratch (random initialization).") + + checkpoint_callback = ModelCheckpoint( dirpath=config1.train.save_dir, filename='{epoch}-{val_loss:.4f}-{val_energy_loss:.4f}-{val_force_loss:.4f}', save_top_k=-1, - every_n_epochs=1, - save_on_train_epoch_end=True, + every_n_epochs=1, + save_on_train_epoch_end=True, monitor='val_loss', mode='min' ) + trainer = pl.Trainer( - devices=config2["num_devices"], - num_nodes=config2["num_nodes"], - strategy='ddp_find_unused_parameters_true', - accelerator="gpu" if config2["num_devices"] > 0 else "cpu", + devices=runtime_config["num_devices"], + num_nodes=runtime_config["num_nodes"], + strategy='ddp_find_unused_parameters_true', + accelerator="gpu" if runtime_config["num_devices"] > 0 and torch.cuda.is_available() else "cpu", max_epochs=config1.epochs, callbacks=[checkpoint_callback], enable_checkpointing=True, - gradient_clip_val=0.5, + gradient_clip_val=0.1, default_root_dir=config1.train.save_dir, accumulate_grad_batches=config1.accumulation_steps, - limit_val_batches=100, ) + + pl_module = Trainer(config1, model, train_dataset, valid_dataset, test_dataset) - model = Trainer(config1, model, train_dataset, valid_dataset, test_dataset) - trainer.fit(model, ckpt_path=config2["ckpt_path"] if config2["resume"] else None) + ckpt_path_arg = runtime_config["ckpt_path"] if runtime_config["resume"] else None + + if runtime_config["resume"] and ckpt_path_arg: + print(f"🔄 Resuming training from checkpoint: {ckpt_path_arg}") + + trainer.fit(pl_module, ckpt_path=ckpt_path_arg) \ No newline at end of file diff --git a/example.json b/example.json deleted file mode 100644 index 5c0511c..0000000 --- a/example.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "t1", - "target": "energy_force", - "train_dataset": "t1", - "train_size": 2000, - "valid_dataset": "t1", - "valid_size": 2000, - "test_dataset": "t1", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 3, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 256, - "cutoff": 4, - "num_radial": 8 - }, - "train": { - "epochs": 200, - "batch_size": 24, - "accumulation_steps": 1, - "vt_batch_size": 24, - "lr": 0.0001, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "t1", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/example_for_help.json b/example_for_help.json deleted file mode 100644 index 43022be..0000000 --- a/example_for_help.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "data": { - "root": "dataset/", // Root directory of the dataset - "dataset_name": "t2", // Dataset name (Just a name) - "target": "energy_force", // Task: predict energy and forces - "train_dataset": "t2", // Training dataset name (Exactly the name of dataset folder) - "train_size": 2000, // Number of training samples - "valid_dataset": "t2_valid", // Validation dataset name (Exactly the name of dataset folder) - "valid_size": 2000, // Number of validation samples - "test_dataset": "t2_valid", // Test dataset name (Exactly the name of dataset folder) - "seed": 42 // Random seed for reproducibility - }, - "model": { - "name": "Alphanet", // Model name - "num_layers": 3, // Number of layers - "num_targets": 1, // Number of target variables - "output_dim": 1, // Output dimension - "compute_forces": true, // Predict forces - "compute_stress": true, // Predict stress - "use_pbc": true, // Use periodic boundary conditions - "hidden_channels": 128, // Hidden layer size - "cutoff": 4, // Cutoff radius for interactions - "num_radial": 8 // Number of radial basis functions - }, - "train": { - "epochs": 200, // Total training epochs - "batch_size": 4, // Batch size - "accumulation_steps": 1, // Gradient accumulation steps - "vt_batch_size": 4, // Batch size for validation/testing - "lr": 0.0001, // Initial learning rate - "lr_decay_factor": 0.8, // Learning rate decay factor - "lr_decay_step_size": 80, // Step size for learning rate decay - "weight_decay": 0, // L2 regularization weight - "save_dir": "t2_15", // Model save directory - "log_dir": "", // Log directory (empty if none) - "disable_tqdm": false, // Disable progress bar - "scheduler": "cosineannealinglr",// Learning rate scheduler - "device": "cuda", // Device for training (GPU) - "energy_loss": "mae", // Loss function for energy (MAE) - "force_loss": "mae", // Loss function for forces (MAE) - "energy_metric": "mae", // Metric for energy evaluation - "force_metric": "mae", // Metric for force evaluation - "energy_coef": 4.0, // Coefficient for energy loss - "force_coef": 100.0, // Coefficient for force loss - "eval_steps": 1 // Evaluation step interval - } -} - diff --git a/lmp_requirements.txt b/lmp_requirements.txt new file mode 100644 index 0000000..3eb9aa4 --- /dev/null +++ b/lmp_requirements.txt @@ -0,0 +1,6 @@ +cython + +pytest +attrs +jinja2 +urllib3 diff --git a/mliap_lammps.md b/mliap_lammps.md new file mode 100644 index 0000000..e857bf2 --- /dev/null +++ b/mliap_lammps.md @@ -0,0 +1,159 @@ +# LAMMPS Installation Guide: ML-IAP with KOKKOS (CUDA) Support + +This guide details the process for building LAMMPS with the Machine Learning Interatomic Potential (`ML-IAP`) package, enabled with Python bindings and KOKKOS acceleration for NVIDIA GPUs. + +## Prerequisites +* **Git** & **CMake** +* **C++ Compiler** (compatible with your MPI/CUDA version) +* **MPI Implementation** (e.g., OpenMPI, MPICH) +* **CUDA Toolkit** (version 11.x or 12.x) +* **Python 3.x** + +--- + +## 1. Get the Source Code +Clone the repository and check out the specific commit hash used for this build to ensure reproducibility. + +```bash +git clone https://github.com/lammps/lammps.git +cd lammps + +# Checkout specific commit for stability/reproducibility +git checkout ccca772 + +``` + +## 2. Prepare the Build Environment + +We will use an out-of-source build directory to keep the source tree clean. We will also start with the `kokkos-cuda` CMake preset. + +```bash +mkdir build-mliap +cd build-mliap + +# Copy the KOKKOS CUDA preset to the build directory +cp ../cmake/presets/kokkos-cuda.cmake ./ + +``` + +### ⚠️ Important Configuration Step + +Before generating the build files, you must edit `kokkos-cuda.cmake` to match your specific GPU architecture. + +1. Open `kokkos-cuda.cmake` in a text editor. +2. Locate the architecture flag (e.g., `-DKokkos_ARCH_...`). +3. Change it to match your GPU (e.g., `Kokkos_ARCH_VOLTA70`, `Kokkos_ARCH_AMPERE80`, `Kokkos_ARCH_HOPPER90`). +* *Reference:* [LAMMPS KOKKOS Build Options](https://docs.lammps.org/Build_extras.html#kokkos) + + + +--- + +## 3. Configure and Compile + +Run CMake to configure the build with ML-IAP, SNAP, and Python support enabled. + +```bash +cmake -C kokkos-cuda.cmake \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_INSTALL_PREFIX=$(pwd) \ + -D BUILD_MPI=ON \ + -D PKG_ML-IAP=ON \ + -D PKG_ML-SNAP=ON \ + -D MLIAP_ENABLE_PYTHON=ON \ + -D PKG_PYTHON=ON \ + -D BUILD_SHARED_LIBS=ON \ + ../cmake + +``` + +**Key Flags Explained:** + +* `PKG_ML-IAP=ON`: Enables the Machine Learning Interatomic Potentials package. +* `MLIAP_ENABLE_PYTHON=ON`: Allows ML-IAP to call Python functions (essential for PyTorch/PyG models). +* `BUILD_SHARED_LIBS=ON`: Builds LAMMPS as a shared library (`.so`), required for the Python module. + +### Compile + +Compile the code using multiple cores (adjust `-j 8` based on your CPU cores). + +```bash +make -j 8 + +``` + +--- + +## 4. Python Environment Setup + +Install the LAMMPS Python interface and the necessary dependencies. + +```bash +# Install the lammps python module into your current environment +make install-python + +# Install dependencies for your ML model +cd ../../ +pip install -r lmp_requirements.txt + +# Install CuPy (Ensure the version matches your CUDA version) +# For CUDA 12.x: +pip install cupy-cuda12x +# For CUDA 11.x, use: pip install cupy-cuda11x + +``` + +--- + +## 5. Running LAMMPS + +Below are the commands to run LAMMPS using the KOKKOS accelerator package on GPUs. +### Convert the checkpoint: +```bash +python alphanet/create_lammps_model.py \ + --config ./pretrained/OMA/oma.json \ + --checkpoint ./pretrained/OMA/alex_0410.ckpt \ + --output ./alphanet_lammps.pt \ + --dtype float64 \ + --device cpu \ + +``` +### Input file: + +Necessary settings: +```bash +units metal +atom_style atomic +newton on +pair_style mliap unified your_converted.pt 0 +``` + +### Single GPU Execution + +Run on 1 GPU without MPI. + +```bash +lmp -k on g 1 -sf kk -pk kokkos newton on neigh half gpu/aware on -in test.in + +``` + +### Multi-GPU Execution + +Run on 2 GPUs using MPI. + +```bash +mpirun -np 2 lmp -k on g 2 -sf kk -pk kokkos newton on neigh half gpu/aware on -in sl.in + +``` + +### Runtime Flags Breakdown + +* `-k on g X`: Enable KOKKOS and use **X** GPUs per node. +* `-sf kk`: **Suffix KOKKOS**. Automatically appends `/kk` to styles in the input script (e.g., `pair_style` becomes `pair_style/kk`). +* `-pk kokkos`: Modifies global KOKKOS parameters: +* `newton on`: Turns on Newton's 3rd law optimizations (often faster for GPUs). +* `neigh half`: Uses a half-neighbor list (often more efficient on GPUs). +* `gpu/aware on`: Optimizes MPI communication if using CUDA-aware MPI. + + + diff --git a/mul_train.py b/mul_train.py deleted file mode 100644 index 8457a23..0000000 --- a/mul_train.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from alphanet.data import get_pic_datasets -from alphanet.models.model import AlphaNetWrapper -from alphanet.config import All_Config -from alphanet.mul_trainer import Trainer -import os -def main(): - config = All_Config().from_json("OC2M-train.json") - train_dataset, valid_dataset, test_dataset = train_dataset, valid_dataset, test_dataset = get_pic_datasets(root='dataset/', name=config.dataset_name,config = config) - force_std = torch.std(train_dataset.data.force).item() - ENERGY_MEAN_TOTAL = 0 - FORCE_MEAN_TOTAL = 0 - NUM_ATOM = None - - for data in valid_dataset: - energy = data.y - force = data.force - NUM_ATOM = force.size()[0] - energy_mean = energy / NUM_ATOM - ENERGY_MEAN_TOTAL += energy_mean - - ENERGY_MEAN_TOTAL /= len(train_dataset) - - config.a = force_std - config.b = ENERGY_MEAN_TOTAL - - model = AlphaNetWrapper(config) - - checkpoint_callback = ModelCheckpoint( - dirpath=config.train.save_dir, - filename='{epoch}-{val_loss:.4f}-{val_energy_mae:.4f}-{val_force_mae:.4f}', - save_top_k=-1, - every_n_epochs=1, - save_on_train_epoch_end=True, - monitor='val_loss', - mode='min' - ) - - early_stopping_callback = EarlyStopping( - monitor='val_loss', - patience= 50, - mode='min' - ) - - trainer = pl.Trainer( - devices=3, - num_nodes=1, - limit_train_batches=40000, - accelerator='auto', - #inference_mode=False, - - strategy='ddp_find_unused_parameters_true', - max_epochs=config.train.epochs, - callbacks=[checkpoint_callback, early_stopping_callback], - default_root_dir=config.train.save_dir, - logger=pl.loggers.TensorBoardLogger(config.train.log_dir), - gradient_clip_val=0.5, - accumulate_grad_batches=config.train.accumulation_steps - ) - - model = Trainer(config, model, train_dataset, valid_dataset, test_dataset) - trainer.fit(model)#, ckpt_path = ckpt) - trainer.test() - -if __name__ == '__main__': - main() diff --git a/pretrained/AQCAT25/README.md b/pretrained/AQCAT25/README.md deleted file mode 100644 index 3a7364e..0000000 --- a/pretrained/AQCAT25/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# AlphaNet-AQCAT25 - -A model trained on the [AQCAT25](https://www.sandboxaq.com/aqcat25), mainly used for surfaces adsorbtion and reactions. The model is trained on the total energies and forces of trainiing set and slabs data. - -## Model Details - -* **Parameters:** Approximately 6.9M - -## Access the Model - -The following resources are available in the `pretrained_models/AQCAT25` path: -* **Model Configuration:** `aqcat.json` -* **Model state\_dict:** Pre-trained weights `aqcat_1021.ckpt` - - -## Performance -| Mae | Value | Unit/Description | -| :--- | :--- | :--- | -| test_id | 0.010,0.088 | eV/atom , eV/$ \AA $| -| test_ood_ads | 0.010,0.082 | eV/atom , eV/$ \AA $ | -| test_ood_both | 0.024, 0.097 | eV/atom , eV/$ \AA $ | -| test_ood_mat | 0.0186, 0.101 | eV/atom , eV/$ \AA $ | -| test_ood_slabs | 0.025, 0.091 | eV/atom , eV/$ \AA $ | diff --git a/pretrained/AQCAT25/aqcat.json b/pretrained/AQCAT25/aqcat.json deleted file mode 100644 index 80e876d..0000000 --- a/pretrained/AQCAT25/aqcat.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "t1", - "target": "energy_force", - "train_dataset": "aqcat", - "valid_dataset": "test_ood_ads", - "test_dataset": "test_ood_mat", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": false, - "main_chi1": 64, - "mp_chi1": 48, - "chi2": 32, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 128, - "cutoff": 6, - "num_radial": 8 - }, - "train": { - "epochs": 500, - "batch_size":32, - "accumulation_steps": 1, - "vt_batch_size": 32, - "lr": 0.0002, - "lr_decay_factor": 0.9, - "lr_decay_step_size": 50000, - "weight_decay": 0.0001, - "save_dir": "./dac", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "steplr", - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "stress_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 0.5, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/AQCAT25/aqcat_1021.ckpt b/pretrained/AQCAT25/aqcat_1021.ckpt deleted file mode 100644 index d79ac9a..0000000 Binary files a/pretrained/AQCAT25/aqcat_1021.ckpt and /dev/null differ diff --git a/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt b/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt deleted file mode 100644 index 1371d3a..0000000 --- a/pretrained/AQCAT25/haiku_model_converted/conversion_map.txt +++ /dev/null @@ -1,160 +0,0 @@ -Flax Key -> Haiku Key Mapping --------------------------------------------------------------------------------- -params/a -> alpha_net_hiku/a -params/b -> alpha_net_hiku/b -params/ftes_0/vec_proj/kernel -> alpha_net_hiku/~/fte/~/linear/w -params/ftes_0/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte/~/linear_1/b -params/ftes_0/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte/~/linear_1/w -params/ftes_0/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte/~/linear_2/b -params/ftes_0/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte/~/linear_2/w -params/ftes_1/vec_proj/kernel -> alpha_net_hiku/~/fte_1/~/linear/w -params/ftes_1/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_1/~/linear_1/b -params/ftes_1/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_1/~/linear_1/w -params/ftes_1/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_1/~/linear_2/b -params/ftes_1/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_1/~/linear_2/w -params/ftes_2/vec_proj/kernel -> alpha_net_hiku/~/fte_2/~/linear/w -params/ftes_2/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_2/~/linear_1/b -params/ftes_2/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_2/~/linear_1/w -params/ftes_2/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_2/~/linear_2/b -params/ftes_2/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_2/~/linear_2/w -params/ftes_3/vec_proj/kernel -> alpha_net_hiku/~/fte_3/~/linear/w -params/ftes_3/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_3/~/linear_1/b -params/ftes_3/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_3/~/linear_1/w -params/ftes_3/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_3/~/linear_2/b -params/ftes_3/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_3/~/linear_2/w -params/kernel1 -> alpha_net_hiku/kernel1 -params/kernels_imag -> alpha_net_hiku/kernels_imag -params/kernels_real -> alpha_net_hiku/kernels_real -params/last_layer/bias -> alpha_net_hiku/~/linear_4/b -params/last_layer/kernel -> alpha_net_hiku/~/linear_4/w -params/last_layer_quantum/bias -> alpha_net_hiku/~/linear_5/b -params/last_layer_quantum/kernel -> alpha_net_hiku/~/linear_5/w -params/lin/layers_0/bias -> alpha_net_hiku/~/linear_2/b -params/lin/layers_0/kernel -> alpha_net_hiku/~/linear_2/w -params/lin/layers_2/bias -> alpha_net_hiku/~/linear_3/b -params/lin/layers_2/kernel -> alpha_net_hiku/~/linear_3/w -params/message_layers_0/dia/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/b -params/message_layers_0/dia/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/w -params/message_layers_0/diachi1 -> alpha_net_hiku/~/equi_message_passing/diachi1 -params/message_layers_0/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/b -params/message_layers_0/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/w -params/message_layers_0/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/b -params/message_layers_0/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/w -params/message_layers_0/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/b -params/message_layers_0/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/w -params/message_layers_0/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/b -params/message_layers_0/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/w -params/message_layers_0/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/offset -params/message_layers_0/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/scale -params/message_layers_0/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing/~/linear/b -params/message_layers_0/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing/~/linear/w -params/message_layers_0/kernel_imag -> alpha_net_hiku/~/equi_message_passing/kernel_imag -params/message_layers_0/kernel_real -> alpha_net_hiku/~/equi_message_passing/kernel_real -params/message_layers_0/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing/linear/b -params/message_layers_0/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing/linear/w -params/message_layers_0/scale/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear/b -params/message_layers_0/scale/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear/w -params/message_layers_0/scale2/bias -> alpha_net_hiku/~/equi_message_passing/linear_1/b -params/message_layers_0/scale2/kernel -> alpha_net_hiku/~/equi_message_passing/linear_1/w -params/message_layers_0/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm/offset -params/message_layers_0/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm/scale -params/message_layers_0/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/b -params/message_layers_0/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/w -params/message_layers_0/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/b -params/message_layers_0/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/w -params/message_layers_1/dia/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/b -params/message_layers_1/dia/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/w -params/message_layers_1/diachi1 -> alpha_net_hiku/~/equi_message_passing_1/diachi1 -params/message_layers_1/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/b -params/message_layers_1/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/w -params/message_layers_1/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/b -params/message_layers_1/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/w -params/message_layers_1/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/b -params/message_layers_1/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/w -params/message_layers_1/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/b -params/message_layers_1/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/w -params/message_layers_1/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/offset -params/message_layers_1/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/scale -params/message_layers_1/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_1/~/linear/b -params/message_layers_1/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_1/~/linear/w -params/message_layers_1/kernel_imag -> alpha_net_hiku/~/equi_message_passing_1/kernel_imag -params/message_layers_1/kernel_real -> alpha_net_hiku/~/equi_message_passing_1/kernel_real -params/message_layers_1/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_1/linear/b -params/message_layers_1/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear/w -params/message_layers_1/scale/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/b -params/message_layers_1/scale/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/w -params/message_layers_1/scale2/bias -> alpha_net_hiku/~/equi_message_passing_1/linear_1/b -params/message_layers_1/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear_1/w -params/message_layers_1/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/offset -params/message_layers_1/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/scale -params/message_layers_1/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/b -params/message_layers_1/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/w -params/message_layers_1/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/b -params/message_layers_1/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/w -params/message_layers_2/dia/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/b -params/message_layers_2/dia/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/w -params/message_layers_2/diachi1 -> alpha_net_hiku/~/equi_message_passing_2/diachi1 -params/message_layers_2/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/b -params/message_layers_2/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/w -params/message_layers_2/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/b -params/message_layers_2/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/w -params/message_layers_2/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/b -params/message_layers_2/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/w -params/message_layers_2/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/b -params/message_layers_2/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/w -params/message_layers_2/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/offset -params/message_layers_2/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/scale -params/message_layers_2/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_2/~/linear/b -params/message_layers_2/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_2/~/linear/w -params/message_layers_2/kernel_imag -> alpha_net_hiku/~/equi_message_passing_2/kernel_imag -params/message_layers_2/kernel_real -> alpha_net_hiku/~/equi_message_passing_2/kernel_real -params/message_layers_2/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_2/linear/b -params/message_layers_2/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear/w -params/message_layers_2/scale/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/b -params/message_layers_2/scale/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/w -params/message_layers_2/scale2/bias -> alpha_net_hiku/~/equi_message_passing_2/linear_1/b -params/message_layers_2/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear_1/w -params/message_layers_2/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/offset -params/message_layers_2/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/scale -params/message_layers_2/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/b -params/message_layers_2/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/w -params/message_layers_2/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/b -params/message_layers_2/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/w -params/message_layers_3/dia/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/b -params/message_layers_3/dia/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/w -params/message_layers_3/diachi1 -> alpha_net_hiku/~/equi_message_passing_3/diachi1 -params/message_layers_3/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/b -params/message_layers_3/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/w -params/message_layers_3/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/b -params/message_layers_3/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/w -params/message_layers_3/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/b -params/message_layers_3/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/w -params/message_layers_3/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/b -params/message_layers_3/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/w -params/message_layers_3/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/offset -params/message_layers_3/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/scale -params/message_layers_3/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_3/~/linear/b -params/message_layers_3/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_3/~/linear/w -params/message_layers_3/kernel_imag -> alpha_net_hiku/~/equi_message_passing_3/kernel_imag -params/message_layers_3/kernel_real -> alpha_net_hiku/~/equi_message_passing_3/kernel_real -params/message_layers_3/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_3/linear/b -params/message_layers_3/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear/w -params/message_layers_3/scale/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/b -params/message_layers_3/scale/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/w -params/message_layers_3/scale2/bias -> alpha_net_hiku/~/equi_message_passing_3/linear_1/b -params/message_layers_3/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear_1/w -params/message_layers_3/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/offset -params/message_layers_3/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/scale -params/message_layers_3/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/b -params/message_layers_3/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/w -params/message_layers_3/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/b -params/message_layers_3/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/w -params/neighbor_emb/Embed_0/embedding -> alpha_net_hiku/~/neighbor_emb/embed/embeddings -params/radial_emb/bessel_weights -> alpha_net_hiku/~/rbf_emb/bessel_weights -params/radial_lin/layers_0/bias -> alpha_net_hiku/~/linear/b -params/radial_lin/layers_0/kernel -> alpha_net_hiku/~/linear/w -params/radial_lin/layers_2/bias -> alpha_net_hiku/~/linear_1/b -params/radial_lin/layers_2/kernel -> alpha_net_hiku/~/linear_1/w -params/s_vector/Dense_0/bias -> alpha_net_hiku/~/s_vector/linear/b -params/s_vector/Dense_0/kernel -> alpha_net_hiku/~/s_vector/linear/w -params/z_emb/embedding -> alpha_net_hiku/~/embed/embeddings diff --git a/pretrained/MATPES/README.md b/pretrained/MATPES/README.md deleted file mode 100644 index a25b077..0000000 --- a/pretrained/MATPES/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# AlphaNet-MATPES-r2scan - -A model trained on the [MATPES](https://matpes.ai/) R2SCAN level dataset. - -## Model Details - -* **Parameters:** Approximately 16M - -## Access the Model - -The following resources are available in the `pretrained_models/MATPES` path: - -* **Model Configuration:** `matpes.json` -* **Model state\_dict:** Pre-trained weights `r2scan_1021.ckpt` - -## Performance -Remain to be evaluated \ No newline at end of file diff --git a/pretrained/MATPES/haiku_model_converted/conversion_map.txt b/pretrained/MATPES/haiku_model_converted/conversion_map.txt deleted file mode 100644 index 1371d3a..0000000 --- a/pretrained/MATPES/haiku_model_converted/conversion_map.txt +++ /dev/null @@ -1,160 +0,0 @@ -Flax Key -> Haiku Key Mapping --------------------------------------------------------------------------------- -params/a -> alpha_net_hiku/a -params/b -> alpha_net_hiku/b -params/ftes_0/vec_proj/kernel -> alpha_net_hiku/~/fte/~/linear/w -params/ftes_0/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte/~/linear_1/b -params/ftes_0/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte/~/linear_1/w -params/ftes_0/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte/~/linear_2/b -params/ftes_0/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte/~/linear_2/w -params/ftes_1/vec_proj/kernel -> alpha_net_hiku/~/fte_1/~/linear/w -params/ftes_1/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_1/~/linear_1/b -params/ftes_1/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_1/~/linear_1/w -params/ftes_1/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_1/~/linear_2/b -params/ftes_1/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_1/~/linear_2/w -params/ftes_2/vec_proj/kernel -> alpha_net_hiku/~/fte_2/~/linear/w -params/ftes_2/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_2/~/linear_1/b -params/ftes_2/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_2/~/linear_1/w -params/ftes_2/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_2/~/linear_2/b -params/ftes_2/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_2/~/linear_2/w -params/ftes_3/vec_proj/kernel -> alpha_net_hiku/~/fte_3/~/linear/w -params/ftes_3/xvec_proj/layers_0/bias -> alpha_net_hiku/~/fte_3/~/linear_1/b -params/ftes_3/xvec_proj/layers_0/kernel -> alpha_net_hiku/~/fte_3/~/linear_1/w -params/ftes_3/xvec_proj/layers_2/bias -> alpha_net_hiku/~/fte_3/~/linear_2/b -params/ftes_3/xvec_proj/layers_2/kernel -> alpha_net_hiku/~/fte_3/~/linear_2/w -params/kernel1 -> alpha_net_hiku/kernel1 -params/kernels_imag -> alpha_net_hiku/kernels_imag -params/kernels_real -> alpha_net_hiku/kernels_real -params/last_layer/bias -> alpha_net_hiku/~/linear_4/b -params/last_layer/kernel -> alpha_net_hiku/~/linear_4/w -params/last_layer_quantum/bias -> alpha_net_hiku/~/linear_5/b -params/last_layer_quantum/kernel -> alpha_net_hiku/~/linear_5/w -params/lin/layers_0/bias -> alpha_net_hiku/~/linear_2/b -params/lin/layers_0/kernel -> alpha_net_hiku/~/linear_2/w -params/lin/layers_2/bias -> alpha_net_hiku/~/linear_3/b -params/lin/layers_2/kernel -> alpha_net_hiku/~/linear_3/w -params/message_layers_0/dia/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/b -params/message_layers_0/dia/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear_1/w -params/message_layers_0/diachi1 -> alpha_net_hiku/~/equi_message_passing/diachi1 -params/message_layers_0/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/b -params/message_layers_0/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear/w -params/message_layers_0/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/b -params/message_layers_0/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_diagonal/linear_1/w -params/message_layers_0/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/b -params/message_layers_0/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear/w -params/message_layers_0/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/b -params/message_layers_0/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_dir_proj/linear_1/w -params/message_layers_0/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/offset -params/message_layers_0/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm_1/scale -params/message_layers_0/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing/~/linear/b -params/message_layers_0/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing/~/linear/w -params/message_layers_0/kernel_imag -> alpha_net_hiku/~/equi_message_passing/kernel_imag -params/message_layers_0/kernel_real -> alpha_net_hiku/~/equi_message_passing/kernel_real -params/message_layers_0/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing/linear/b -params/message_layers_0/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing/linear/w -params/message_layers_0/scale/bias -> alpha_net_hiku/~/equi_message_passing/~message/linear/b -params/message_layers_0/scale/kernel -> alpha_net_hiku/~/equi_message_passing/~message/linear/w -params/message_layers_0/scale2/bias -> alpha_net_hiku/~/equi_message_passing/linear_1/b -params/message_layers_0/scale2/kernel -> alpha_net_hiku/~/equi_message_passing/linear_1/w -params/message_layers_0/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing/layer_norm/offset -params/message_layers_0/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing/layer_norm/scale -params/message_layers_0/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/b -params/message_layers_0/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear/w -params/message_layers_0/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/b -params/message_layers_0/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing/~_build_x_proj/linear_1/w -params/message_layers_1/dia/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/b -params/message_layers_1/dia/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear_1/w -params/message_layers_1/diachi1 -> alpha_net_hiku/~/equi_message_passing_1/diachi1 -params/message_layers_1/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/b -params/message_layers_1/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear/w -params/message_layers_1/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/b -params/message_layers_1/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_diagonal/linear_1/w -params/message_layers_1/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/b -params/message_layers_1/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear/w -params/message_layers_1/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/b -params/message_layers_1/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_dir_proj/linear_1/w -params/message_layers_1/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/offset -params/message_layers_1/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm_1/scale -params/message_layers_1/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_1/~/linear/b -params/message_layers_1/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_1/~/linear/w -params/message_layers_1/kernel_imag -> alpha_net_hiku/~/equi_message_passing_1/kernel_imag -params/message_layers_1/kernel_real -> alpha_net_hiku/~/equi_message_passing_1/kernel_real -params/message_layers_1/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_1/linear/b -params/message_layers_1/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear/w -params/message_layers_1/scale/bias -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/b -params/message_layers_1/scale/kernel -> alpha_net_hiku/~/equi_message_passing_1/~message/linear/w -params/message_layers_1/scale2/bias -> alpha_net_hiku/~/equi_message_passing_1/linear_1/b -params/message_layers_1/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_1/linear_1/w -params/message_layers_1/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/offset -params/message_layers_1/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_1/layer_norm/scale -params/message_layers_1/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/b -params/message_layers_1/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear/w -params/message_layers_1/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/b -params/message_layers_1/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_1/~_build_x_proj/linear_1/w -params/message_layers_2/dia/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/b -params/message_layers_2/dia/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear_1/w -params/message_layers_2/diachi1 -> alpha_net_hiku/~/equi_message_passing_2/diachi1 -params/message_layers_2/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/b -params/message_layers_2/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear/w -params/message_layers_2/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/b -params/message_layers_2/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_diagonal/linear_1/w -params/message_layers_2/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/b -params/message_layers_2/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear/w -params/message_layers_2/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/b -params/message_layers_2/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_dir_proj/linear_1/w -params/message_layers_2/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/offset -params/message_layers_2/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm_1/scale -params/message_layers_2/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_2/~/linear/b -params/message_layers_2/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_2/~/linear/w -params/message_layers_2/kernel_imag -> alpha_net_hiku/~/equi_message_passing_2/kernel_imag -params/message_layers_2/kernel_real -> alpha_net_hiku/~/equi_message_passing_2/kernel_real -params/message_layers_2/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_2/linear/b -params/message_layers_2/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear/w -params/message_layers_2/scale/bias -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/b -params/message_layers_2/scale/kernel -> alpha_net_hiku/~/equi_message_passing_2/~message/linear/w -params/message_layers_2/scale2/bias -> alpha_net_hiku/~/equi_message_passing_2/linear_1/b -params/message_layers_2/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_2/linear_1/w -params/message_layers_2/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/offset -params/message_layers_2/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_2/layer_norm/scale -params/message_layers_2/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/b -params/message_layers_2/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear/w -params/message_layers_2/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/b -params/message_layers_2/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_2/~_build_x_proj/linear_1/w -params/message_layers_3/dia/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/b -params/message_layers_3/dia/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear_1/w -params/message_layers_3/diachi1 -> alpha_net_hiku/~/equi_message_passing_3/diachi1 -params/message_layers_3/diagonal/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/b -params/message_layers_3/diagonal/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear/w -params/message_layers_3/diagonal/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/b -params/message_layers_3/diagonal/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_diagonal/linear_1/w -params/message_layers_3/dir_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/b -params/message_layers_3/dir_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear/w -params/message_layers_3/dir_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/b -params/message_layers_3/dir_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_dir_proj/linear_1/w -params/message_layers_3/dx_layer_norm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/offset -params/message_layers_3/dx_layer_norm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm_1/scale -params/message_layers_3/fc_mps/bias -> alpha_net_hiku/~/equi_message_passing_3/~/linear/b -params/message_layers_3/fc_mps/kernel -> alpha_net_hiku/~/equi_message_passing_3/~/linear/w -params/message_layers_3/kernel_imag -> alpha_net_hiku/~/equi_message_passing_3/kernel_imag -params/message_layers_3/kernel_real -> alpha_net_hiku/~/equi_message_passing_3/kernel_real -params/message_layers_3/rbf_proj/bias -> alpha_net_hiku/~/equi_message_passing_3/linear/b -params/message_layers_3/rbf_proj/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear/w -params/message_layers_3/scale/bias -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/b -params/message_layers_3/scale/kernel -> alpha_net_hiku/~/equi_message_passing_3/~message/linear/w -params/message_layers_3/scale2/bias -> alpha_net_hiku/~/equi_message_passing_3/linear_1/b -params/message_layers_3/scale2/kernel -> alpha_net_hiku/~/equi_message_passing_3/linear_1/w -params/message_layers_3/x_layernorm/bias -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/offset -params/message_layers_3/x_layernorm/scale -> alpha_net_hiku/~/equi_message_passing_3/layer_norm/scale -params/message_layers_3/x_proj/layers_0/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/b -params/message_layers_3/x_proj/layers_0/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear/w -params/message_layers_3/x_proj/layers_2/bias -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/b -params/message_layers_3/x_proj/layers_2/kernel -> alpha_net_hiku/~/equi_message_passing_3/~_build_x_proj/linear_1/w -params/neighbor_emb/Embed_0/embedding -> alpha_net_hiku/~/neighbor_emb/embed/embeddings -params/radial_emb/bessel_weights -> alpha_net_hiku/~/rbf_emb/bessel_weights -params/radial_lin/layers_0/bias -> alpha_net_hiku/~/linear/b -params/radial_lin/layers_0/kernel -> alpha_net_hiku/~/linear/w -params/radial_lin/layers_2/bias -> alpha_net_hiku/~/linear_1/b -params/radial_lin/layers_2/kernel -> alpha_net_hiku/~/linear_1/w -params/s_vector/Dense_0/bias -> alpha_net_hiku/~/s_vector/linear/b -params/s_vector/Dense_0/kernel -> alpha_net_hiku/~/s_vector/linear/w -params/z_emb/embedding -> alpha_net_hiku/~/embed/embeddings diff --git a/pretrained/MATPES/haiku_model_converted/haiku_model.pkl b/pretrained/MATPES/haiku_model_converted/haiku_model.pkl deleted file mode 100644 index 5bfdbe1..0000000 Binary files a/pretrained/MATPES/haiku_model_converted/haiku_model.pkl and /dev/null differ diff --git a/pretrained/MATPES/matpes.json b/pretrained/MATPES/matpes.json deleted file mode 100644 index 2618eeb..0000000 --- a/pretrained/MATPES/matpes.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "MP-TRAIN", - "target": "energy_force", - "train_dataset": "r2scan-train", - "valid_dataset": "r2scan-val", - - "test_dataset": "r2scan-val", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 176, - "main_chi1": 96, - "mp_chi1": 64, - "chi2": 24, - "cutoff": 5, - "num_radial": 8, - "zbl": true - }, - "train": { - "epochs": 200, - "batch_size": 16, - "accumulation_steps": 1, - "vt_batch_size": 16, - "lr": 0.0005, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "r2scan", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "norm_label": false, - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 10, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/MATPES/r2scan_1021.ckpt b/pretrained/MATPES/r2scan_1021.ckpt deleted file mode 100644 index 686c1c5..0000000 Binary files a/pretrained/MATPES/r2scan_1021.ckpt and /dev/null differ diff --git a/pretrained/MPtrj/README.md b/pretrained/MPtrj/README.md deleted file mode 100644 index c1749b2..0000000 --- a/pretrained/MPtrj/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# AlphaNet-MPtrj-v1 - -A model trained on the MpTrj dataset. - -## Model Details - -* **Parameters:** Approximately 4.5 million - -## Access the Model - -The following resources are available in the `pretrained_models/MPtrj` path: - -* **Model Configuration:** `mp.json` -* **Model state\_dict:** Pre-trained weights can be downloaded from [Figshare](https://ndownloader.figshare.com/files/53851133). - -## Performance on WBM Test Set - -The detailed evaluation metrics for the model on the `full_test_set` are as follows: - -| Metric | Value | Unit/Description | -| :--- | :--- | :--- | -| F1 | 0.789 | fraction | -| DAF | 4.312 | dimensionless | -| Precision | 0.74 | fraction | -| Recall | 0.846 | fraction | -| Accuracy | 0.923 | fraction | -| TPR | 0.846 | fraction | -| FPR | 0.062 | fraction | -| TNR | 0.938 | fraction | -| FNR | 0.154 | fraction | -| TP | 37311.0 | count | -| FP | 13119.0 | count | -| TN | 199752.0 | count | -| FN | 6781.0 | count | -| MAE | 0.04 | eV/atom | -| RMSE | 0.091 | eV/atom | -| R2 | 0.747 | dimensionless | \ No newline at end of file diff --git a/pretrained/MPtrj/mp.json b/pretrained/MPtrj/mp.json deleted file mode 100644 index 3abf01f..0000000 --- a/pretrained/MPtrj/mp.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "MP-TRAIN", - "target": "energy_force", - "train_dataset": "MP_train", - "valid_dataset": "MP-train", - "test_dataset": "MP_val", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 4, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": true, - "use_pbc": true, - "has_dropout_flag": false, - "hidden_channels": 176, - "cutoff": 5, - "num_radial": 8 - }, - "train": { - "epochs": 200, - "batch_size": 6, - "accumulation_steps": 1, - "vt_batch_size": 2, - "lr": 0.0005, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "wbm", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "norm_label": false, - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 4.0, - "stress_coef": 10, - "force_coef": 100.0, - "eval_steps": 1 - } -} diff --git a/pretrained/OMA/README.md b/pretrained/OMA/README.md index 7c5ea71..c54c4b9 100644 --- a/pretrained/OMA/README.md +++ b/pretrained/OMA/README.md @@ -28,6 +28,6 @@ Same size with **AlphaNet-MPtrj-v1**, trained on OMAT24, and finetuned on sALEX+ The following resources are available in the directory: * **Model Configuration**: `oma.json` -* **Model `state_dict`**: Pre-trained weights can be downloaded from [Figshare](https://ndownloader.figshare.com/files/53851139). +* **Model `state_dict`**: [alex_1212.ckpt](./alex_1212.ckpt). -**Path**: `pretrained_models/OMA` \ No newline at end of file +**Path**: `pretrained_models/OMA` diff --git a/pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl b/pretrained/OMA/alex_1212.ckpt similarity index 59% rename from pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl rename to pretrained/OMA/alex_1212.ckpt index e89d847..35df5cb 100644 Binary files a/pretrained/AQCAT25/haiku_model_converted/haiku_model.pkl and b/pretrained/OMA/alex_1212.ckpt differ diff --git a/requirements.txt b/requirements.txt index f314aec..d4a1e3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ --extra-index-url https://download.pytorch.org/whl/cu121 - +numpy==1.26.4 +matscipy==1.1.1 torch==2.1.2 torch_geometric lightning @@ -9,4 +10,3 @@ tensorboard --find-links https://pytorch-geometric.com/whl/torch-2.1.2%2Bcu121.html torch_scatter - diff --git a/run_lammps.py b/run_lammps.py new file mode 100644 index 0000000..00e670c --- /dev/null +++ b/run_lammps.py @@ -0,0 +1,116 @@ +import sys +import os +import torch +import lammps as lmp_lib +import lammps.mliap + +try: + import alphanet.infer.lammps_mliap_alphanet +except ImportError: + sys.path.append(os.getcwd()) + import alphanet.infer.lammps_mliap_alphanet + + +_ORIGINAL_JIT_LOAD = torch.jit.load +_ORIGINAL_TORCH_LOAD = torch.load + +GLOBAL_LOADED_MODEL = None + +def hijack_load(f, *args, **kwargs): + """ + 拦截 LAMMPS (或其 Python 接口) 的加载请求。 + 直接返回我们预加载好的 Python 对象。 + """ + f_str = str(f) + print(f"⚡ Intercepted load request for '{f_str}'") + + if GLOBAL_LOADED_MODEL is not None: + print(" => Returning pre-loaded LAMMPS_MLIAP_ALPHANET Object!") + return GLOBAL_LOADED_MODEL + + print(" => Warning: Global model not set. Fallback to original load.") + # 根据文件后缀猜测加载方式 + if f_str.endswith('.pt') or f_str.endswith('.pth'): + return _ORIGINAL_TORCH_LOAD(f, *args, **kwargs) + return _ORIGINAL_JIT_LOAD(f, *args, **kwargs) + +# 替换 torch 的加载函数,以便被 LAMMPS 调用时触发钩子 +torch.jit.load = hijack_load +torch.load = hijack_load + +# --------------------------------------------------------- +# 主程序 +# --------------------------------------------------------- +if __name__ == "__main__": + input_file = "sl.in" # LAMMPS 输入文件 + model_file = "sl.pt" # create_lammps_model.py 生成的文件 + + if not os.path.exists(input_file): + print(f"❌ Error: {input_file} not found!") + sys.exit(1) + + if not os.path.exists(model_file): + print(f"❌ Error: {model_file} not found!") + sys.exit(1) + + print(f"✅ PyTorch {torch.__version__} loaded.") + + # 1. 确定设备 + if torch.cuda.is_available(): + device = torch.device("cuda") + print(f"✅ Using device: CUDA (GPU)") + else: + device = torch.device("cpu") + print(f"⚠️ Using device: CPU") + + # 2. 预加载模型对象 (Pre-load) + print(f"🛠️ Loading Python object from {model_file}...") + try: + # 使用原始的 torch.load 加载保存的 Python 对象 + loaded_object = _ORIGINAL_TORCH_LOAD(model_file, map_location=device) + + # 确保模型及其内部参数都在正确的设备上 + if hasattr(loaded_object, 'model'): + loaded_object.model.to(device) + loaded_object.device = device # 更新对象内部记录的 device + loaded_object.model.eval() + + # 注册到全局变量,供钩子使用 + GLOBAL_LOADED_MODEL = loaded_object + print(" Model loaded and registered successfully.") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"❌ Failed to load model: {e}") + sys.exit(1) + + # 3. 配置 LAMMPS + # -k on g 1: 开启 Kokkos 并使用 1 个 GPU + # neigh half: Kokkos 默认要求,我们已经在 Python 端通过“智能对称化”解决了这个问题 + # newton on: 必须开启,用于跨进程的力通信 + cmd_args = [ + "-k", "on", "g", "1", + "-sf", "kk", + "-pk", "kokkos", "newton", "on", "neigh", "half" + ] + + try: + print(f"🚀 Initializing LAMMPS...") + lmp = lmp_lib.lammps(cmdargs=cmd_args) + + print("🔌 Activating ML-IAP Kokkos interface...") + # 这行代码会触发 C++ 调用 Python 来加载模型 + # 此时会命中我们的 hijack_load,并返回 GLOBAL_LOADED_MODEL + lammps.mliap.activate_mliappy_kokkos(lmp) + + print(f"📂 Executing {input_file}...") + lmp.file(input_file) + + print("🎉 LAMMPS simulation finished successfully.") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"❌ Error during simulation: {e}") + sys.exit(1) \ No newline at end of file diff --git a/setup.py b/setup.py index aa53ddd..2467b04 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ 'pyfiglet', 'rich', 'ase', + 'matscipy==1.1.1', 'rdkit', 'pydantic', 'scikit-learn', diff --git a/water.json b/water.json deleted file mode 100644 index 657de49..0000000 --- a/water.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "data": { - "root": "dataset/", - "dataset_name": "t1", - "target": "energy_force", - "train_dataset": "water-train", - "valid_dataset": "water-test", - "test_dataset": "water-test", - "seed": 42 - }, - "model": { - "name": "Alphanet", - "num_layers": 3, - "num_targets": 1, - "output_dim": 1, - "compute_forces": true, - "compute_stress": true, - "use_pbc": true, - "has_dropout_flag": true, - "hidden_channels": 128, - "cutoff": 4, - "num_radial": 8 - }, - "train": { - "epochs": 500, - "batch_size":1, - "accumulation_steps": 16, - "vt_batch_size": 1, - "lr": 0.0005, - "lr_decay_factor": 0.8, - "lr_decay_step_size": 80, - "weight_decay": 0, - "save_dir": "gap", - "log_dir": "", - "disable_tqdm": false, - "scheduler": "cosineannealinglr", - "device": "cuda", - "energy_loss": "mae", - "force_loss": "mae", - "stress_loss": "mae", - "energy_metric": "mae", - "force_metric": "mae", - "energy_coef": 1.0, - "stress_coef": 0.5, - "force_coef": 100.0, - "eval_steps": 1 - } -}