How will we interface with different molecular dynamics (MD) codes? #42
Description
We can write a machine learning potential (MLP) class in Python that can be used within a thin driver/calculator class in ASE and i-PI, like:
```python
import ase
from torch.nn import Module  # assuming a torch-backed model


class MlPotential(Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    # ASE and i-PI both support ase.Atoms; if other types are needed,
    # a conversion step can be added here
    def compute(self, frame: ase.Atoms):
        self.output = self.model.forward(frame)

    def get_energy(self):
        energies = self.output.block().values
        # do some additional work to get coherent units
        return energies

    def get_forces(self):
        ...
```
```python
# in ase.calculators
from ase.calculators.calculator import Calculator


class MlPotentialCalculator(Calculator):
    ...  # loads an MlPotential


# in the i-PI drivers
class MlPotentialDriver(Dummy_driver):
    ...  # loads an MlPotential
```

To me it was not clear how we would use this for any MD code that does not support Python (like LAMMPS). For the models that use torch, the way to go is clearly TorchScript. But I was thinking about the models that are numpy based: how should we handle the interfaces for those?
Given our resources and the practicality of the other approaches I explored (see below), it seems that for the numpy models the only reasonable approach is to also export them as TorchScript. We can nevertheless keep the option to use numpy throughout for training and for running the model in Python. But when exporting a model to run in an MD code whose drivers only support low-level languages, we convert it to a TorchScript-compatible format. We will also offer a way to export the model keeping the numpy arrays, but that exported model will only work for MD codes with Python drivers, such as i-PI and ASE.
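To make the export step concrete, here is a minimal sketch of the TorchScript round trip (script, save, load). The `PairEnergy` toy module is a hypothetical stand-in for an exported model, not our actual architecture; the point is only that the saved `.pt` file can be loaded without Python, e.g. from LAMMPS via libtorch:

```python
import torch


class PairEnergy(torch.nn.Module):
    """Toy stand-in for an MLP: half the sum of squared pair displacements."""

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # (N, 3) positions -> (N, N, 3) pairwise differences
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        return 0.5 * (diff ** 2).sum()


model = PairEnergy()
scripted = torch.jit.script(model)   # compile to TorchScript
scripted.save("model.pt")            # serialized, Python-free format

loaded = torch.jit.load("model.pt")  # what a C++ driver would do via libtorch
pos = torch.zeros(3, 3)
energy = loaded(pos)
```

A numpy model would need its forward pass re-expressed in torch ops for this to work, which is the conversion step described above.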
Long-range
For long-range interactions, the support for MD codes with Python drivers works. To get MPI support from LAMMPS, however, it seems to me that we need a different approach in kspace, which means a different class and interface.
Alternatives to TorchScript? Using JAX
It seems to me that JAX does not yet have the infrastructure that TorchScript offers with its JIT-compiled custom operators. I only found one GitHub issue discussing how to load a JAX JIT-compiled function in C++, jax-ml/jax#5337 (comment), and it looks like a lot of work for just one simple function.
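For comparison, the Python side of JAX is trivial; it is only the export to C++ that lacks tooling. A minimal sketch, assuming JAX is installed (the `energy` function is a toy stand-in, mirroring the torch example above):

```python
import jax
import jax.numpy as jnp


@jax.jit  # traced and compiled by XLA on first call
def energy(positions):
    # (N, 3) positions -> (N, N, 3) pairwise differences
    diff = positions[None, :, :] - positions[:, None, :]
    return 0.5 * jnp.sum(diff ** 2)


e = energy(jnp.zeros((3, 3)))
```

Everything up to this point is as easy as with torch; the gap is that there is no standard, supported way to serialize the compiled function and call it from a C++ MD driver, which is exactly what `torch.jit.save` / libtorch give us.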
Alternatives to TorchScript? Running Python code from C/C++
I just skimmed through this guide, https://stackoverflow.com/a/1056057, which links to https://www.linuxjournal.com/article/8497, but it looks like it just opens the door to many more low-level issues.
Link collection of how MLP codes interface with MD codes: