Changes from all commits (31 commits)
8f29053
feat: add lammps interface
Nov 24, 2025
92d3023
Update run_lammps.py
zmyybc Nov 24, 2025
b48797d
Disable model load interception in hijack_load
zmyybc Nov 24, 2025
89db9a0
Update aqcat.json
zmyybc Nov 25, 2025
68abee8
correct_lammps
Dec 28, 2025
73f35ef
Update README.md
zmyybc Dec 28, 2025
09b344e
Add LAMMPS Installation Guide for ML-IAP with KOKKOS
zmyybc Dec 29, 2025
04738f3
Revise README for AlphaNet v0.1.2-beta updates
zmyybc Dec 29, 2025
5d00880
Update LAMMPS running instructions with new commands
zmyybc Dec 29, 2025
2d545f0
Delete pretrained/AQCAT25 directory
zmyybc Dec 29, 2025
eaa84b4
Delete pretrained/MATPES directory
zmyybc Dec 29, 2025
9ee3d45
Delete pretrained/MPtrj directory
zmyybc Dec 29, 2025
2f1bbb6
Add files via upload
zmyybc Dec 29, 2025
6f98478
Add initial dependencies to lmp_requirements.txt
zmyybc Dec 29, 2025
daae474
Update model state_dict reference in README
zmyybc Dec 29, 2025
f5de720
Revise README with updates and installation guide
zmyybc Dec 29, 2025
cdf3918
Update installation instructions in mliap_lammps.md
zmyybc Dec 31, 2025
979cb30
Fix git clone command formatting in documentation
zmyybc Dec 31, 2025
f5ed52b
Create setup.py
zmyybc Jan 2, 2026
af13b10
Add numpy version requirement to requirements.txt
zmyybc Jan 2, 2026
23f09c8
Add files via upload
yckbz Feb 24, 2026
06197bc
fix: wrap atomic positions in calculator
yckbz Feb 24, 2026
71e6249
Update position retrieval to include wrapping
zmyybc Feb 26, 2026
c2963b0
Fix tensor creation warning in calc.py
yckbz Feb 26, 2026
8e7b2f4
Update evaler.py
yckbz Feb 28, 2026
b661ef9
Optimize ASE inference graph reuse
yckbz Mar 12, 2026
ac0a957
Use matscipy sparse neighbor topology for torch inference
yckbz Mar 12, 2026
1ffccf8
Merge branch 'lammps' into lammps
yckbz Mar 12, 2026
c4f2059
Remove max_num_neighbors limit and fix device mismatch bug
yckbz Mar 23, 2026
6caf2c5
Fix non-reproducible train/test split in get_idx_split
yckbz Mar 23, 2026
77edfde
Merge pull request #9 from yckbz/lammps
zmyybc Mar 23, 2026
24 changes: 24 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# --- .gitignore content start ---

# 1. Ignore LAMMPS log and trajectory files
log.lammps
*.log
*.lammpstrj

# 2. Ignore data and structure files
*.data
POSCAR
CH4.txt
*pt
*.pt
*.pkl
# 3. Ignore large model weight files (committing large files to git is generally discouraged unless you need to)

# 4. Ignore Python bytecode caches and packaging artifacts
__pycache__/
*.egg-info/
build/
dist/
*.zip
lammps/
# --- .gitignore content end ---
84 changes: 19 additions & 65 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,14 @@

We present **AlphaNet**, a local frame-based equivariant model designed to deliver both accurate and efficient simulations of atomistic systems. **AlphaNet** improves computational efficiency and accuracy by exploiting the local geometric structure of atomic environments through equivariant local frames and learnable frame transitions. Inspired by quantum mechanics, AlphaNet **introduces efficient multi-body message passing via contraction of matrix product states** rather than the common 2-body message passing. Notably, AlphaNet offers one of the best trade-offs between computational efficiency and accuracy among existing models, and it scales across a broad spectrum of system and dataset sizes, affirming its versatility.
## Update Log (v0.1.2)
## Update Log (v0.1.2-beta)

### Major Changes

1. **Added 2 new pretrained models**
- Provide a pretrained model for materials, **AlphaNet-MATPES-r2scan**, and our first pretrained model for catalysis, **AlphaNet-AQCAT25**; see the [pretrained](./pretrained) folder.
- Users can **convert the checkpoint trained in torch to our JAX model**

2. **Fixed some bugs**
- Supported non-periodic boundary conditions in our ASE calculator.
- Fixed float64-related errors.
1. **Added a LAMMPS ML-IAP interface**
2. **Slightly changed the model architecture**
3. **Added a finetune option**



## Installation Guide
Expand Down Expand Up @@ -84,7 +81,11 @@ alpha-train example.json # use --help to see more functions, like multi-gpu trai
```bash
alpha-conv -i in.ckpt -o out.ckpt # use --help to see more functions
```
3. Evaluate a model and draw diagonal plot:
2. Finetune a converted ckpt:
```bash
alpha-train example.json --finetune /path/to/your.ckpt
```
4. Evaluate a model and draw diagonal plot:
```bash
alpha-eval -c example.json -m /path/to/ckpt # use --help to see more functions
```
Expand Down Expand Up @@ -142,67 +143,17 @@ print(atoms.get_potential_energy())

```

### Using AlphaNet in JAX
1. Installation
```bash
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
```
This is just for reference. JAX installation may be tricky; see the [JAX](https://docs.jax.dev/en/latest/installation.html) documentation and its GitHub issues for more information.

Currently we suggest **jax >=0.4, <=0.4.10; or >=0.4.30, <=0.5; or ==0.6.2**.

Install matscipy, flax, and haiku:
```bash
pip install matscipy
pip install flax
pip install -U dm-haiku
```

2. Converted checkpoints:

See pretrained directory

3. Convert a self-trained ckpt

First from torch to flax:
```bash
python scripts/conv_pt2flax.py  # edit the paths inside it before running
```
Then from flax to haiku:

```bash
python scripts/flax2haiku.py  # edit the paths inside it before running
```

4. Performance:

The output (energy, forces, stress) difference from the torch model is below 0.001. In speed tests on an RTX 4090 GPU with system sizes from 4 to 300 atoms, we observed a **2.5x to 3x** speedup.

Note that the JAX model must be compiled first, so the first run can take from a few seconds to a few minutes; subsequent runs are fast.

## Dataset Download

[The Defected Bilayer Graphene Dataset](https://zenodo.org/records/10374206)

[The Formate Decomposition on Cu Dataset](https://archive.materialscloud.org/record/2022.45)

[The Zeolite Dataset](https://doi.org/10.6084/m9.figshare.27800211)

[The OC dataset](https://opencatalystproject.org/)

[The MPtrj dataset](https://matbench-discovery.materialsproject.org/data)

## Pretrained Models

Current pretrained models:
Current pretrained models (due to the architecture changes, previous pretrained models need updating, which will be done as soon as possible):

For materials:
- [AlphaNet-MPtrj-v1](pretrained/MPtrj): A model trained on the MPtrj dataset.
- [AlphaNet-oma-v1](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.
- [AlphaNet-MATPES-r2scan](pretrained/MATPES): A model trained on the MATPES-r2scan dataset.

For surface adsorption and reactions:
- [AlphaNet-AQCAT25](pretrained/AQCAT25): A model trained on the AQCAT25 dataset.
- [AlphaNet-oma-v1.5](pretrained/OMA): A model trained on the OMAT24 dataset, and finetuned on sALEX+MPtrj.

## Use AlphaNet in LAMMPS

See [mliap_lammps](mliap_lammps.md)

## License

Expand All @@ -222,3 +173,6 @@ We thank all contributors and the community for their support. Please open an is






6 changes: 4 additions & 2 deletions alphanet/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def display_config_table(main_config, runtime_config):
@click.option("--num_devices", type=int, default=1, help="GPUs per node")
@click.option("--resume", is_flag=True, help="Resume training from checkpoint")
@click.option("--ckpt_path", type=click.Path(), default=None, help="Path to checkpoint file")
def main(config, num_nodes, num_devices, resume, ckpt_path):
@click.option("--finetune", type=click.Path(exists=True), default=None, help="Path to pretrained checkpoint for finetuning (resets optimizer)")
def main(config, num_nodes, num_devices, resume, ckpt_path, finetune):

with open(config, "r") as f:
mconfig = json.load(f)
Expand All @@ -67,7 +68,8 @@ def main(config, num_nodes, num_devices, resume, ckpt_path):
"num_nodes": num_nodes,
"num_devices": num_devices,
"resume": resume,
"ckpt_path": ckpt_path
"ckpt_path": ckpt_path,
"finetune_path": finetune
}

display_header()
Expand Down
16 changes: 13 additions & 3 deletions alphanet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import subprocess
import json
import torch
#import torch
from typing import Literal, Dict, Optional
from pydantic_settings import BaseSettings

try:
VERSION = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
subprocess.check_output(
["git", "rev-parse", "HEAD"],
stderr=subprocess.DEVNULL,
).decode().strip()
)
except Exception:
VERSION = "NA"
Expand All @@ -22,6 +25,7 @@ class TrainConfig(BaseSettings):
batch_size: int = 32
vt_batch_size: int = 32
lr: float = 0.0005
optimizer: str = "radam"
lr_decay_factor: float = 0.5
lr_decay_step_size: int = 150
weight_decay: float = 0
Expand Down Expand Up @@ -86,7 +90,13 @@ class AlphaConfig(BaseSettings):
has_norm_after_flag: bool = False
reduce_mode: str = "sum"
zbl: bool = False
device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
zbl_w: Optional[list] = [0.187,0.3769,0.189,0.081,0.003,0.037,0.0546,0.0715]
zbl_b: Optional[list] = [3.20,1.10,0.102,0.958,1.28,1.14,1.69,5]
zbl_gamma: float = 1.001
zbl_alpha: float = 0.6032
zbl_E2: float = 14.399645478425
zbl_A0: float = 0.529177210903
device: str = "cuda"



Expand Down
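One motivation for the `device` change in `AlphaConfig` above (from a `torch.device` object to the string `"cuda"`) is serializability: a `torch.device` cannot be dumped to JSON, while a plain string round-trips cleanly and can be resolved to a device at use time. A minimal stdlib-only sketch (the dict keys mirror fields from the config diff; the resolution step is an assumption about how consumers use it):

```python
import json

# A config that embeds a torch.device object cannot be json.dumps'd;
# storing the device as a plain string keeps the config serializable.
config_as_str = {"device": "cuda", "lr": 0.0005, "optimizer": "radam"}
serialized = json.dumps(config_as_str)
restored = json.loads(serialized)
print(restored["device"])  # the string survives a JSON round-trip

# Consumers can then resolve it lazily, e.g. (hypothetical usage):
# device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
```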
91 changes: 91 additions & 0 deletions alphanet/create_lammps_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import argparse
import os
import torch
from pathlib import Path

# Import the AlphaNet model wrapper and config
from alphanet.models.model import AlphaNetWrapper
from alphanet.config import All_Config

# Import the Python-level LAMMPS interface class
try:
from alphanet.infer.lammps_mliap_alphanet import LAMMPS_MLIAP_ALPHANET
except ImportError:
print("Could not import LAMMPS_MLIAP_ALPHANET.")
print("Please ensure 'alphanet/infer/lammps_mliap_alphanet.py' exists.")
exit(1)


def parse_args():
parser = argparse.ArgumentParser(
description="Convert an AlphaNet model to LAMMPS ML-IAP format (Python Pickle)",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config", "-c", required=True, type=str,
help="Path to the model configuration JSON file",
)
parser.add_argument(
"--checkpoint", "-m", required=True, type=str,
help="Path to the trained model checkpoint (.ckpt)",
)
parser.add_argument(
"--output", "-o", required=True, type=str,
help="Output path to save the model (e.g., alphanet_lammps.pt)",
)
parser.add_argument(
"--device", type=str, default="cpu",
help="Device to load the model on ('cpu' or 'cuda')",
)
parser.add_argument(
"--dtype", type=str, default="float64",
choices=["float32", "float64"],
help="Data type for the model",
)
return parser.parse_args()

def main():
args = parse_args()

device = torch.device(args.device)

print(f"1. Loading configuration from {args.config}...")
config_obj = All_Config().from_json(args.config)

config_obj.model.dtype = "64" if args.dtype == "float64" else "32"

print(f"2. Initializing AlphaNetWrapper (precision: {args.dtype}, device: {args.device})...")
model_wrapper = AlphaNetWrapper(config_obj.model)

print(f"3. Loading weights from {args.checkpoint}...")
ckpt = torch.load(args.checkpoint, map_location=device)

if 'state_dict' in ckpt:
state_dict = {k.replace('model.', ''): v for k, v in ckpt['state_dict'].items()}
model_wrapper.model.load_state_dict(state_dict, strict=False)
else:
model_wrapper.load_state_dict(ckpt, strict=False)

if args.dtype == "float64":
model_wrapper.double()
else:
model_wrapper.float()

model_wrapper.to(device).eval()

print("4. Creating LAMMPS ML-IAP Interface Object...")
lammps_interface_object = LAMMPS_MLIAP_ALPHANET(model_wrapper)

if device.type == 'cuda':
lammps_interface_object.model.cuda()

print(f"5. Saving Python object to {args.output}...")
# Using standard torch.save for Python pickle compatibility
torch.save(lammps_interface_object, args.output)

print("\n--- Success ---")
print(f"Created LAMMPS model file: {args.output}")
print("Usage in LAMMPS: pair_style mliap model/python ...")

if __name__ == "__main__":
main()
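Putting the converter script above to use might look like the sketch below. The flags come from the script's own `argparse` definitions; the paths are placeholders, and the `pair_style` line is copied verbatim from the script's printed hint, so verify the exact keywords against [mliap_lammps](mliap_lammps.md):

```
# Convert a trained checkpoint to a LAMMPS ML-IAP pickle (placeholder paths):
python alphanet/create_lammps_model.py -c config.json -m model.ckpt -o alphanet_lammps.pt --dtype float64

# Then, in the LAMMPS input, per the script's hint:
pair_style mliap model/python ...
```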
8 changes: 5 additions & 3 deletions alphanet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm import tqdm
import torch
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import joblib
from torch_geometric.data import Data, DataLoader, InMemoryDataset, download_url, extract_zip

Expand All @@ -26,7 +27,7 @@ def get_pic_datasets(root, name, config):
test_dataset = test_dataset[test_indices]
else:

dataset = CustomPickleDataset(name=name, root=root)
dataset = CustomPickleDataset(name=name, root=root, config=config)


split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=config.train_size, valid_size=config.valid_size, test_size=config.test_size, seed=config.seed)
Expand Down Expand Up @@ -120,8 +121,9 @@ def process(self):
print('Saving...')
torch.save((data, slices), self.processed_paths[0])

def get_idx_split(self, data_size, train_size=None, valid_size=None, seed=None):
ids = shuffle(list(range(data_size)))
#def get_idx_split(self, data_size, train_size=None, valid_size=None, seed=None):
def get_idx_split(self, data_size, train_size=None, valid_size=None, test_size=None, seed=None):
ids = shuffle(list(range(data_size)), random_state=seed)
if train_size is not None and valid_size is None:
train_idx = ids[:train_size]

Expand Down
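The `get_idx_split` fix above addresses the commit "Fix non-reproducible train/test split": without a seed, each call to `shuffle` produces a different permutation, so train/val/test membership changes between runs. A stdlib-only sketch of the seeded behavior (`random.Random(seed).shuffle` stands in for sklearn's `shuffle(ids, random_state=seed)`; the function name and split logic mirror the diff, simplified):

```python
import random

def get_idx_split(data_size, train_size, valid_size, seed):
    # Seeding the shuffle makes the permutation -- and hence the
    # train/valid/test membership -- identical across runs.
    ids = list(range(data_size))
    random.Random(seed).shuffle(ids)
    train_idx = ids[:train_size]
    valid_idx = ids[train_size:train_size + valid_size]
    test_idx = ids[train_size + valid_size:]
    return train_idx, valid_idx, test_idx

a = get_idx_split(100, 70, 15, seed=42)
b = get_idx_split(100, 70, 15, seed=42)
assert a == b  # same seed, same split, every run
```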
4 changes: 2 additions & 2 deletions alphanet/evaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def plot_force_parity(self, train_loader, val_loader, test_loader, plots_dir=Non
mask = deviation < threshold
preds_force_filtered = preds_force[mask]
targets_force_filtered = targets_force[mask]
force_mae_filtered = 0.5*torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item()
force_rmse_filtered = 0.5*torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item()
force_mae_filtered = torch.mean(torch.abs(preds_force_filtered - targets_force_filtered)).item()
force_rmse_filtered = torch.sqrt(torch.mean((preds_force_filtered - targets_force_filtered) ** 2)).item()

plt.scatter(
targets_force_filtered.cpu().numpy(),
Expand Down
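The `evaler.py` change above removes a spurious `0.5` factor that halved the reported force MAE and RMSE. A pure-Python stand-in for the corrected torch expressions (toy force components, element-wise over a flat list rather than a tensor):

```python
import math

def force_mae(pred, target):
    # Mean absolute error, with no extraneous 0.5 scaling.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

def force_rmse(pred, target):
    # Root-mean-square error, likewise unscaled.
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred))

pred = [0.1, -0.2, 0.3]
target = [0.0, 0.0, 0.0]
print(force_mae(pred, target))   # ~0.2; the old 0.5-scaled code would report ~0.1
```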