Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions mdsapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml

import MDAnalysis as mda
from mdsapt.repair import ChargeStrategy, StandardChargeStrategy

from mdsapt.utils.ensemble import Ensemble

Expand All @@ -37,7 +38,8 @@ class Psi4Config(BaseModel):
Attributes:
method:
The SAPT method to use.
NOTE: You can use any valid Psi4 method, but it might fail if you don't use a SAPT method.
NOTE: You can use any valid Psi4 method, but it might fail if you don't
use a valid one.
basis:
The basis to use.
NOTE: We do not verify if this is a valid basis set or not.
Expand Down Expand Up @@ -70,6 +72,14 @@ class ChargeGuesser(Enum):
STANDARD = 'standard'
RDKIT = 'rdkit'

@property
def charge_strategy(self) -> ChargeStrategy:
    """
    Returns the :class:`ChargeStrategy` implementation selected by this enum member.

    :raises NotImplementedError: if the member is ``RDKIT``; no strategy is
        provided for it here.
    :raises ValueError: if the member is neither ``STANDARD`` nor ``RDKIT``.
    """
    if self is ChargeGuesser.STANDARD:
        return StandardChargeStrategy()
    if self is ChargeGuesser.RDKIT:
        # Must be NotImplementedError: `NotImplemented` is a comparison sentinel,
        # not an exception class, and calling it raises a TypeError instead.
        raise NotImplementedError('RDKit guesser is not implemented yet.')
    raise ValueError(f'Unknown charge guesser {self}')


class SimulationConfig(BaseModel):
"""
Expand All @@ -86,13 +96,14 @@ class SimulationConfig(BaseModel):
@dataclass
class TopologySelection:
"""
A configuration item for selecting a single topology. To successfully import a topology,
it must be supported by MDAnalysis.
A configuration item for selecting a single topology. To successfully
import a topology, it must be supported by MDAnalysis.

Attributes:
path: Where the topology file is located.
topology_format: If specified, overrides the format to import with.
charge_overrides: An optional dictionary, where keys are atom numbers and values are their charges.
charge_overrides: An optional dictionary, where keys are atom numbers and
values are their charges.

.. seealso::
`List of topology formats that MDAnalysis supports <https://docs.mdanalysis.org/1.1.1/documentation_pages/topology/init.html>`_
Expand Down Expand Up @@ -321,19 +332,41 @@ def _check_valid_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
ligands=values.get('ligands'))
missing_selections: List[int] = []

for v in ens.values():
missing_selections += get_invalid_residue_selections(protein_selections, v)
for val in ens.values():
missing_selections += get_invalid_residue_selections(protein_selections, val)

if len(missing_selections) > 0:
errors.append(f'Selected residues are missing from topology: {missing_selections}')

return values

def build_ensemble(self) -> Ensemble:
    """
    Creates the :class:`Ensemble` described by this configuration.

    The ``combined_topologies``, ``protein`` and ``ligands`` attributes are
    forwarded unchanged, as keyword arguments, to ``_build_ensemble``.
    """
    topology_sources = {
        'combined_topologies': self.combined_topologies,
        'protein': self.protein,
        'ligands': self.ligands,
    }
    return self._build_ensemble(**topology_sources)

def build_overrides_dict(self) -> Dict[str, Dict[int, int]]:
    """
    Collects the charge overrides of every topology in this configuration.

    Exactly one of two layouts is accepted:

    * only ``combined_topologies`` is set, or
    * both ``protein`` and ``ligands`` are set and ``combined_topologies`` is not.

    :return: A dictionary keyed by ``str(top.path)`` whose values are the
        ``charge_overrides`` attribute of each topology, passed through as-is.
        NOTE(review): a topology without overrides presumably contributes
        ``None`` rather than an empty dict -- confirm against
        ``TopologySelection``.
    :raises ValueError: if the configuration matches neither layout.
    """
    # Identity checks avoid invoking a custom ``__eq__`` on the configuration objects.
    has_combined = self.combined_topologies is not None
    has_protein = self.protein is not None
    has_ligands = self.ligands is not None

    if has_combined and not has_protein and not has_ligands:
        selections = list(self.combined_topologies.get_individual_topologies())
    elif not has_combined and has_protein and has_ligands:
        # The protein goes last so that, on a path clash, its overrides win over a ligand's.
        selections = [*self.ligands.get_individual_topologies(), self.protein]
    else:
        raise ValueError('Must provide `protein` and `ligands` keys, or only `combined_topologies`')

    return {str(top.path): top.charge_overrides for top in selections}

@classmethod
def _build_ensemble(
cls,
Expand Down
102 changes: 80 additions & 22 deletions mdsapt/repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,92 @@

"""

from typing import Set, Union
from abc import ABC
from typing import Dict, List, NamedTuple, Optional, Set, Union

import logging

import numpy as np

import MDAnalysis as mda
from MDAnalysis.converters.RDKit import atomgroup_to_mol
from MDAnalysis.core.groups import Atom
from MDAnalysis.topology.guessers import guess_types, guess_atom_element

from rdkit import Chem

from pdbfixer import PDBFixer
from simtk.openmm.app import PDBFile

from mdsapt.utils.formal_charge import ElectronInfo, calculate_electron_info, calculate_spin_multiplicity

logger = logging.getLogger('mdsapt.optimizer')


def get_spin_multiplicity(molecule: Chem.Mol) -> int:
"""Returns the spin multiplicity of a :class:`RDKIT.Mol`.
Based on method in http://www.mayachemtools.org/docs/modules/html/code/RDKitUtil.py.html .
class MoleculeElectronInfo(NamedTuple):
    """
    Charge and radical electron values collected for a molecule, with
    convenience properties that aggregate them.
    """
    # One charge value per entry; summed by ``total_charge``.
    charges: List[int]
    # One radical electron count per entry; summed by ``total_radicals``.
    radicals: List[int]

    @property
    def total_charge(self) -> int:
        """The sum of all values in ``charges``."""
        net_charge = sum(self.charges)
        return net_charge

    @property
    def total_radicals(self) -> int:
        """The sum of all values in ``radicals``."""
        radical_count = sum(self.radicals)
        return radical_count

    @property
    def spin_multiplicity(self):
        """The result of ``calculate_spin_multiplicity`` applied to ``total_radicals``."""
        unpaired = self.total_radicals
        return calculate_spin_multiplicity(unpaired)

:Arguments:
*molecule*
:class:`RDKIT.Mol` object

class ChargeStrategy(ABC):
    """
    An interface for specifying a charge/radical electron guessing algorithm.

    Subclasses are expected to override :meth:`calculate`; the base
    implementation always raises :class:`NotImplementedError`.
    """

    def calculate(
            self,
            ag: mda.AtomGroup,
            charge_overrides: Dict[int, int]
    ) -> MoleculeElectronInfo:
        """
        Calculates :class:`MoleculeElectronInfo` for the given :class:`mda.AtomGroup`.

        :param ag: The :class:`mda.AtomGroup`
        :param charge_overrides: Charges supplied by the user that take the
            place of guessed ones. NOTE(review): keys are presumably atom
            indices -- confirm against the concrete strategies.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class StandardChargeStrategy(ChargeStrategy):
    """
    The standard charge guessing strategy. This is usually the one you would want to use.

    Each atom's bond orders are summed and passed, together with the atom's
    element and any override for that atom, to ``calculate_electron_info``.
    """

    def calculate(
            self,
            ag: mda.AtomGroup,
            charge_overrides: Optional[Dict[int, int]]
    ) -> MoleculeElectronInfo:
        """
        Calculates :class:`MoleculeElectronInfo` for the given :class:`mda.AtomGroup`.

        :param ag: The :class:`mda.AtomGroup`
        :param charge_overrides: Overrides looked up by ``atom.ix``; ``None``
            is treated as an empty dictionary.
        :return: The formal charges and radical counts of every non-aromatic atom.
        """
        if charge_overrides is None:
            charge_overrides = {}

        def calc_atom(atom: Atom) -> Optional[ElectronInfo]:
            bonds = atom.get_connections('bonds', outside=True)
            # A bond without an explicit order is counted as a single bond.
            orders = [1 if b.order is None else b.order for b in bonds]

            # Assume any atom with an aromatic bond has fc=0 and radicals=0.
            # Return None to signal this. The check happens before the override
            # lookup, so overrides on such atoms are not consulted.
            if 1.5 in orders or 'ar' in orders:
                return None

            bond_count: int = sum(orders)
            return calculate_electron_info(atom.element, bond_count, charge_overrides.get(atom.ix))

        # Drop aromatics from the results
        results = [info for info in map(calc_atom, ag) if info is not None]

        charges = [r.fc for r in results]
        radicals = [r.radical for r in results]

        return MoleculeElectronInfo(charges=charges, radicals=radicals)


def is_amino(unv: mda.Universe, resid: int) -> bool:
Expand Down Expand Up @@ -84,7 +136,13 @@ def is_amino(unv: mda.Universe, resid: int) -> bool:
return resname_atr.values[resid - 1] in std_resids


def rebuild_resid(resid: int, residue: mda.AtomGroup, sim_ph: float = 7.0) -> mda.AtomGroup:
def rebuild_resid(
resid: int,
residue: mda.AtomGroup,
charge_strategy: ChargeStrategy,
sim_ph: float = 7.0,
charge_overrides: Optional[Dict[int, int]] = None
) -> mda.AtomGroup:
"""Rebuilds residue by replacing missing protons and adding a new proton
on the C terminus. Raises key error if class
has no value for that optimization."""
Expand Down Expand Up @@ -119,12 +177,12 @@ def get_new_pos(backbone: mda.AtomGroup, length: float) -> np.ndarray:

def protonate_backbone(bkbone: mda.AtomGroup, length: float = 1.128) -> \
Union[mda.AtomGroup, mda.Universe]:
mol_resid = atomgroup_to_mol(bkbone)
i: int = 0
for atom in mol_resid.GetAtoms():
i += atom.GetNumRadicalElectrons()
mol_info: MoleculeElectronInfo = charge_strategy.calculate(
bkbone,
{} if charge_overrides is None else charge_overrides
)

if i > 0:
if mol_info.total_radicals > 0:
backbone = bkbone.select_atoms('backbone')
protonated: mda.Universe = mda.Universe.empty(n_atoms=bkbone.n_atoms + 1,
trajectory=True)
Expand Down
59 changes: 40 additions & 19 deletions mdsapt/sapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,37 +28,35 @@

from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.topology.guessers import guess_types
from MDAnalysis.converters.RDKit import atomgroup_to_mol
from MDAnalysis.lib.log import ProgressBar

import psi4

from pydantic import ValidationError

from rdkit import Chem

from .config import Config, TrajectoryAnalysisConfig, DockingAnalysisConfig, \
Psi4Config, SysLimitsConfig
from .repair import rebuild_resid, get_spin_multiplicity
from .repair import ChargeStrategy, MoleculeElectronInfo, rebuild_resid
from .utils.ensemble import Ensemble, EnsembleAtomGroup

logger = logging.getLogger('mdsapt.sapt')

MHT_TO_KCALMOL: Final[float] = 627.509


def build_psi4_input_str(
        resid: int,
        residue: mda.AtomGroup,
        strategy: ChargeStrategy,
        charge_overrides: Optional[Dict[int, int]] = None
) -> str:
    """
    Generates the Psi4 input geometry for the specified residue.

    The residue is first passed through :func:`rebuild_resid`. The total charge
    and spin multiplicity reported by ``strategy`` for the rebuilt residue are
    written on the first line, followed by one ``element x y z`` line per atom.

    :param resid: The residue number, forwarded to :func:`rebuild_resid`.
    :param residue: The atoms of the residue.
    :param strategy: The :class:`ChargeStrategy` used to obtain charge and radical information.
    :param charge_overrides: Optional charge overrides; ``None`` is handed to
        ``strategy`` as an empty dictionary.
    :return: The newline-joined input block.
    """
    repaired_resid: mda.AtomGroup = rebuild_resid(
        resid,
        residue,
        strategy,
        charge_overrides=charge_overrides
    )

    # ChargeStrategy.calculate is declared to take a dictionary, so normalise
    # None here instead of relying on every strategy to tolerate it.
    overrides: Dict[int, int] = {} if charge_overrides is None else charge_overrides
    result: MoleculeElectronInfo = strategy.calculate(repaired_resid, overrides)

    lines: List[str] = [f'{result.total_charge} {result.spin_multiplicity}']
    for atom in repaired_resid.atoms:
        lines.append(f'{atom.element} {atom.position[0]} {atom.position[1]} {atom.position[2]}')

    return '\n'.join(lines)


def calc_sapt(psi4_input: str, psi4_cfg: Psi4Config, sys_cfg: SysLimitsConfig,
Expand Down Expand Up @@ -149,6 +147,8 @@ def __init__(self, config: Config, **universe_kwargs) -> None:
for x in ag_sel
}

self._charge_overrides = config.analysis.topology.charge_overrides

self._sel_pairs = config.analysis.pairs
AnalysisBase.__init__(self, self._unv.trajectory)

Expand All @@ -158,7 +158,16 @@ def _prepare(self) -> None:

def _single_frame(self) -> None:
outfile: Optional[str] = None
xyz_dict = {k: build_psi4_input_str(k, self._sel[k]) for k in self._sel.keys()}
xyz_dict = {
k: build_psi4_input_str(
k,
resid,
self._cfg.simulation.charge_guesser.charge_strategy,
charge_overrides=self._charge_overrides
)
for k, resid in self._sel.items()
}

for pair in self._sel_pairs:
coords = xyz_dict[pair[0]] + '\n--\n' + xyz_dict[pair[1]] + '\nunits angstrom'

Expand Down Expand Up @@ -225,6 +234,7 @@ def __init__(self, config: Config) -> None:
raise err

self._ens = self._cfg.analysis.build_ensemble()
self._charge_overrides = self._cfg.analysis.build_overrides_dict()

self._sel_pairs = self._cfg.analysis.pairs

Expand All @@ -240,21 +250,32 @@ def _prepare(self) -> None:
self._pair_names = {pair: f'{pair[0]}-{pair[1]}' for pair in self._sel_pairs}

def _single_system(self) -> None:
xyz_dict = {k: build_psi4_input_str(k, self._sel[k][self._key]) for k in self._sel.keys()}
key_name = self._key_names[self._key]
charge_overrides = self._charge_overrides[self._key]

xyz_dict = {
k: build_psi4_input_str(
k,
self._sel[k][self._key],
self._cfg.simulation.charge_guesser.charge_strategy,
charge_overrides=charge_overrides
)
for k, resid in self._sel.items()
}
outfile: Optional[str] = None

for pair in self._sel_pairs:

coords = xyz_dict[pair[0]] + '\n--\n' + xyz_dict[pair[1]] + '\nunits angstrom'
for a, b in self._sel_pairs:
pair_name = self._pair_names[(a, b)]
coords = xyz_dict[a] + '\n--\n' + xyz_dict[b] + '\nunits angstrom'

logger.info(f'Starting SAPT for {pair}')
logger.info(f'Starting SAPT for %a, %b', a, b)

if self._cfg.psi4.save_output:
outfile = f'sapt_{self._key_names[self._key]}_{self._pair_names[pair]}.out'
outfile = f'sapt_{key_name}_{pair_name}.out'

sapt_result = calc_sapt(coords, self._cfg.psi4, self._cfg.system_limits, outfile)
result: List[Union[str, float]] = \
[self._key_names[self._key], self._pair_names[pair]] + \
[key_name, pair_name] + \
[sapt_result[k] for k in self._SAPT_KEYS]

for i, res in enumerate(result):
Expand All @@ -273,8 +294,8 @@ def run(self) -> None:
trajectory frames.
"""
logger.info("Setting up systems")
self._prepare()
for self._key in ProgressBar(self._ens.keys(), verbose=True):
self._prepare()
self._single_system()
logger.info("Moving to next universe")
logger.info("Finishing up")
Expand Down
Loading