diff --git a/mdsapt/config.py b/mdsapt/config.py index 0053c03..3df8b29 100644 --- a/mdsapt/config.py +++ b/mdsapt/config.py @@ -25,6 +25,7 @@ import yaml import MDAnalysis as mda +from mdsapt.repair import ChargeStrategy, StandardChargeStrategy from mdsapt.utils.ensemble import Ensemble @@ -37,7 +38,8 @@ class Psi4Config(BaseModel): Attributes: method: The SAPT method to use. - NOTE: You can use any valid Psi4 method, but it might fail if you don't use a SAPT method. + NOTE: You can use any valid Psi4 method, but it might fail if you don't + use a valid one. basis: The basis to use. NOTE: We do not verify if this is a valid basis set or not. @@ -70,6 +72,14 @@ class ChargeGuesser(Enum): STANDARD = 'standard' RDKIT = 'rdkit' + @property + def charge_strategy(self) -> ChargeStrategy: + if self == ChargeGuesser.STANDARD: + return StandardChargeStrategy() + if self == ChargeGuesser.RDKIT: + raise NotImplemented('RDKit guesser is not implemented yet.') + raise ValueError(f'Unknown charge guesser {self}') + class SimulationConfig(BaseModel): """ @@ -86,13 +96,14 @@ class SimulationConfig(BaseModel): @dataclass class TopologySelection: """ - A configuration item for selecting a single topology. To successfully import a topology, - it must be supported by MDAnalysis. + A configuration item for selecting a single topology. To successfully + import a topology, it must be supported by MDAnalysis. Attributes: path: Where the topology file is located. topology_format: If specified, overrides the format to import with. - charge_overrides: An optional dictionary, where keys are atom numbers and values are their charges. + charge_overrides: An optional dictionary, where keys are atom numbers and + values are their charges. .. seealso:: `List of topology formats that MDAnalysis supports `_ @@ -321,8 +332,8 @@ def _check_valid_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: ligands=values.get('ligands')) missing_selections: List[int] = [] - for v in ens.values(): - missing_selections += get_invalid_residue_selections(protein_selections, v) + for val in ens.values(): + missing_selections += get_invalid_residue_selections(protein_selections, val) if len(missing_selections) > 0: errors.append(f'Selected residues are missing from topology: {missing_selections}') @@ -330,10 +341,32 @@ def _check_valid_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values def build_ensemble(self) -> Ensemble: + """ + Builds an ensemble from this configuration data. + """ return self._build_ensemble(combined_topologies=self.combined_topologies, protein=self.protein, ligands=self.ligands) + def build_overrides_dict(self) -> Dict[str, Dict[int, int]]: + if self.combined_topologies is not None and (self.protein, self.ligands) == (None, None): + sels = self.combined_topologies.get_individual_topologies() + return { + str(top.path): top.charge_overrides + for top in sels + } + + if self.combined_topologies is None and None not in (self.protein, self.ligands): + ligand_sels = self.ligands.get_individual_topologies() + result = { + str(top.path): top.charge_overrides + for top in ligand_sels + } + result[str(self.protein.path)] = self.protein.charge_overrides + return result + + raise ValueError('Must provide `protein` and `ligands` keys, or only `combined_topologies`') + @classmethod def _build_ensemble( cls, diff --git a/mdsapt/repair.py b/mdsapt/repair.py index f8a834a..e0c7da0 100644 --- a/mdsapt/repair.py +++ b/mdsapt/repair.py @@ -20,40 +20,92 @@ """ -from typing import Set, Union +from abc import ABC +from typing import Dict, List, NamedTuple, Optional, Set, Union import logging import numpy as np import MDAnalysis as mda -from MDAnalysis.converters.RDKit import atomgroup_to_mol +from MDAnalysis.core.groups import Atom from MDAnalysis.topology.guessers import guess_types, guess_atom_element -from rdkit import Chem - from pdbfixer import PDBFixer from simtk.openmm.app import PDBFile +from mdsapt.utils.formal_charge import ElectronInfo, calculate_electron_info, calculate_spin_multiplicity + logger = logging.getLogger('mdsapt.optimizer') -def get_spin_multiplicity(molecule: Chem.Mol) -> int: - """Returns the spin multiplicity of a :class:`RDKIT.Mol`. - Based on method in http://www.mayachemtools.org/docs/modules/html/code/RDKitUtil.py.html . +class MoleculeElectronInfo(NamedTuple): + charges: List[int] + radicals: List[int] + + @property + def total_charge(self): + return sum(self.charges) + + @property + def total_radicals(self): + return sum(self.radicals) + + @property + def spin_multiplicity(self): + return calculate_spin_multiplicity(self.total_radicals) - :Arguments: - *molecule* - :class:`RDKIT.Mol` object + +class ChargeStrategy(ABC): + """ + An interface for specifying a charge/radical electron guessing algorithm. """ - radical_electrons: int = 0 - for atom in molecule.GetAtoms(): - radical_electrons += atom.GetNumRadicalElectrons() + def calculate( + self, + ag: mda.AtomGroup, + charge_overrides: Dict[int, int] + ) -> MoleculeElectronInfo: + """ + Calculates :class:`MoleculeElectronInfo` for the given :class:`mda.AtomGroup`. + + :param ag: The :class:`mda.AtomGroup` + """ + raise NotImplementedError + + +class StandardChargeStrategy(ChargeStrategy): + """ + The standard charge guessing strategy. This is usually the one you would want to use. + """ + + def calculate( + self, + ag: mda.AtomGroup, + charge_overrides: Dict[int, int] + ) -> MoleculeElectronInfo: + + if charge_overrides is None: + charge_overrides = {} + + def calc_atom(atom: Atom) -> ElectronInfo: + bonds = atom.get_connections('bonds', outside=True) + orders = [1 if b.order is None else b.order for b in bonds] + + # Assume any atom with an aromatic bond has fc=0 and radicals=0. Return None to signal this. + if 1.5 in orders or 'ar' in orders: + return None + + bond_count: int = sum(orders) + return calculate_electron_info(atom.element, bond_count, charge_overrides.get(atom.ix)) + + # Drop aromatics from the results + results = [a for a in map(calc_atom, ag) if a is not None] + + charges = [r.fc for r in results] + radicals = [r.radical for r in results] - total_spin: int = radical_electrons // 2 - spin_mult: int = total_spin + 1 - return spin_mult + return MoleculeElectronInfo(charges=charges, radicals=radicals) def is_amino(unv: mda.Universe, resid: int) -> bool: @@ -84,7 +136,13 @@ def is_amino(unv: mda.Universe, resid: int) -> bool: return resname_atr.values[resid - 1] in std_resids -def rebuild_resid(resid: int, residue: mda.AtomGroup, sim_ph: float = 7.0) -> mda.AtomGroup: +def rebuild_resid( + resid: int, + residue: mda.AtomGroup, + charge_strategy: ChargeStrategy, + sim_ph: float = 7.0, + charge_overrides: Optional[Dict[int, int]] = None +) -> mda.AtomGroup: """Rebuilds residue by replacing missing protons and adding a new proton on the C terminus. Raises key error if class has no value for that optimization.""" @@ -119,12 +177,12 @@ def get_new_pos(backbone: mda.AtomGroup, length: float) -> np.ndarray: def protonate_backbone(bkbone: mda.AtomGroup, length: float = 1.128) -> \ Union[mda.AtomGroup, mda.Universe]: - mol_resid = atomgroup_to_mol(bkbone) - i: int = 0 - for atom in mol_resid.GetAtoms(): - i += atom.GetNumRadicalElectrons() + mol_info: MoleculeElectronInfo = charge_strategy.calculate( + bkbone, + {} if charge_overrides is None else charge_overrides + ) - if i > 0: + if mol_info.total_radicals > 0: backbone = bkbone.select_atoms('backbone') protonated: mda.Universe = mda.Universe.empty(n_atoms=bkbone.n_atoms + 1, trajectory=True) diff --git a/mdsapt/sapt.py b/mdsapt/sapt.py index f0678fb..1992319 100644 --- a/mdsapt/sapt.py +++ b/mdsapt/sapt.py @@ -28,18 +28,15 @@ from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.topology.guessers import guess_types -from MDAnalysis.converters.RDKit import atomgroup_to_mol from MDAnalysis.lib.log import ProgressBar import psi4 from pydantic import ValidationError -from rdkit import Chem - from .config import Config, TrajectoryAnalysisConfig, DockingAnalysisConfig, \ Psi4Config, SysLimitsConfig -from .repair import rebuild_resid, get_spin_multiplicity +from .repair import ChargeStrategy, MoleculeElectronInfo, rebuild_resid from .utils.ensemble import Ensemble, EnsembleAtomGroup logger = logging.getLogger('mdsapt.sapt') @@ -47,18 +44,19 @@ MHT_TO_KCALMOL: Final[float] = 627.509 -def build_psi4_input_str(resid: int, residue: mda.AtomGroup) -> str: +def build_psi4_input_str(resid: int, residue: mda.AtomGroup, strategy: ChargeStrategy, charge_overrides: Optional[Dict[int, int]] = None) -> str: """ Generates Psi4 input file the specified residue. Prepares amino acids for SAPT using :class:`mdsapt.optimizer.Optimizer`. Adds charge and spin multiplicity to top of cooridnates. """ - repaired_resid: mda.AtomGroup = rebuild_resid(resid, residue) - rd_mol = atomgroup_to_mol(repaired_resid) + repaired_resid: mda.AtomGroup = rebuild_resid(resid, residue, strategy, charge_overrides=charge_overrides) + result: MoleculeElectronInfo = strategy.calculate(repaired_resid, charge_overrides) - coords: str = f'{Chem.GetFormalCharge(rd_mol)} {get_spin_multiplicity(rd_mol)}' + lines: List[str] = [f'{result.total_charge} {result.spin_multiplicity}'] for atom in repaired_resid.atoms: - coords += f'\n{atom.element} {atom.position[0]} {atom.position[1]} {atom.position[2]}' - return coords + lines.append(f'{atom.element} {atom.position[0]} {atom.position[1]} {atom.position[2]}') + + return '\n'.join(lines) def calc_sapt(psi4_input: str, psi4_cfg: Psi4Config, sys_cfg: SysLimitsConfig, @@ -149,6 +147,8 @@ def __init__(self, config: Config, **universe_kwargs) -> None: for x in ag_sel } + self._charge_overrides = config.analysis.topology.charge_overrides + self._sel_pairs = config.analysis.pairs AnalysisBase.__init__(self, self._unv.trajectory) @@ -158,7 +158,16 @@ def _prepare(self) -> None: def _single_frame(self) -> None: outfile: Optional[str] = None - xyz_dict = {k: build_psi4_input_str(k, self._sel[k]) for k in self._sel.keys()} + xyz_dict = { + k: build_psi4_input_str( + k, + resid, + self._cfg.simulation.charge_guesser.charge_strategy, + charge_overrides=self._charge_overrides + ) + for k, resid in self._sel.items() + } + for pair in self._sel_pairs: coords = xyz_dict[pair[0]] + '\n--\n' + xyz_dict[pair[1]] + '\nunits angstrom' @@ -225,6 +234,7 @@ def __init__(self, config: Config) -> None: raise err self._ens = self._cfg.analysis.build_ensemble() + self._charge_overrides = self._cfg.analysis.build_overrides_dict() self._sel_pairs = self._cfg.analysis.pairs @@ -240,21 +250,32 @@ def _prepare(self) -> None: self._pair_names = {pair: f'{pair[0]}-{pair[1]}' for pair in self._sel_pairs} def _single_system(self) -> None: - xyz_dict = {k: build_psi4_input_str(k, self._sel[k][self._key]) for k in self._sel.keys()} + key_name = self._key_names[self._key] + charge_overrides = self._charge_overrides[self._key] + + xyz_dict = { + k: build_psi4_input_str( + k, + self._sel[k][self._key], + self._cfg.simulation.charge_guesser.charge_strategy, + charge_overrides=charge_overrides + ) + for k, resid in self._sel.items() + } outfile: Optional[str] = None - for pair in self._sel_pairs: - - coords = xyz_dict[pair[0]] + '\n--\n' + xyz_dict[pair[1]] + '\nunits angstrom' + for a, b in self._sel_pairs: + pair_name = self._pair_names[(a, b)] + coords = xyz_dict[a] + '\n--\n' + xyz_dict[b] + '\nunits angstrom' - logger.info(f'Starting SAPT for {pair}') + logger.info(f'Starting SAPT for %a, %b', a, b) if self._cfg.psi4.save_output: - outfile = f'sapt_{self._key_names[self._key]}_{self._pair_names[pair]}.out' + outfile = f'sapt_{key_name}_{pair_name}.out' sapt_result = calc_sapt(coords, self._cfg.psi4, self._cfg.system_limits, outfile) result: List[Union[str, float]] = \ - [self._key_names[self._key], self._pair_names[pair]] + \ + [key_name, pair_name] + \ [sapt_result[k] for k in self._SAPT_KEYS] for i, res in enumerate(result): @@ -273,8 +294,8 @@ def run(self) -> None: trajectory frames. """ logger.info("Setting up systems") + self._prepare() for self._key in ProgressBar(self._ens.keys(), verbose=True): - self._prepare() self._single_system() logger.info("Moving to next universe") logger.info("Finishing up") diff --git a/mdsapt/tests/test_formal_charge.py b/mdsapt/tests/test_formal_charge.py new file mode 100644 index 0000000..ef25f45 --- /dev/null +++ b/mdsapt/tests/test_formal_charge.py @@ -0,0 +1,44 @@ +from typing import Optional +import pytest +from ..utils.formal_charge import calculate_electron_info, ElectronInfo + + +@pytest.mark.parametrize('element, bonds, fc, fields', [ + # fmt: off + + ('H', 1, None, dict(valence=1, fc= 0, lone=0, radical=0)), + + # Typical C bonding + ('C', 2, None, dict(valence=4, fc= 0, lone=2, radical=0)), # Carbene + ('C', 3, None, dict(valence=4, fc=-1, lone=2, radical=0)), # C in Carbon monoxide (CO) + ('C', 4, None, dict(valence=4, fc= 0, lone=0, radical=0)), # Normal C + ('C', 1, 0, dict(valence=4, fc= 0, lone=3, radical=3)), # Carbyne + ('C', 3, -1, dict(valence=4, fc=-1, lone=2, radical=0)), # Carboanion + ('C', 3, 1, dict(valence=4, fc= 1, lone=0, radical=0)), # Carbocation + + ('N', 1, None, dict(valence=5, fc= 0, lone=4, radical=2)), # Nitrene + ('N', 2, None, dict(valence=5, fc=-1, lone=4, radical=0)), # N in Nitric oxide + ('N', 3, None, dict(valence=5, fc= 0, lone=2, radical=0)), # Normal N + ('N', 4, None, dict(valence=5, fc= 1, lone=0, radical=0)), # N in NH4+ + + ('O', 1, None, dict(valence=6, fc=-1, lone=6, radical=0)), # O in Hydroxide ion (OH-) + ('O', 2, None, dict(valence=6, fc= 0, lone=4, radical=0)), # Normal O + ('O', 3, None, dict(valence=6, fc= 1, lone=2, radical=0)), # O in Hydronium (H3O+) + ('O', 1, 0, dict(valence=6, fc= 0, lone=5, radical=1)), # O in Hydroxyl radical (OH) + + ('P', 5, None, dict(valence=5, fc= 0, lone=0, radical=0)), # P in Phosphate + ('P', 6, None, dict(valence=5, fc=-1, lone=0, radical=0)), # P in PF6- + + # Halogens + ('F', 1, None, dict(valence=7, fc= 0, lone=6, radical=0)), + ('Cl', 1, None, dict(valence=7, fc= 0, lone=6, radical=0)), + ('Br', 1, None, dict(valence=7, fc= 0, lone=6, radical=0)), + ('I', 1, None, dict(valence=7, fc= 0, lone=6, radical=0)), + + # fmt: on +]) +def test_calculate_electron_info(element: str, bonds: int, fc: Optional[int], fields: dict): + actual = calculate_electron_info(element, bonds, fc) + + expected = ElectronInfo(element=element, bonds=bonds, **fields) + assert actual == expected diff --git a/mdsapt/tests/test_repair.py b/mdsapt/tests/test_repair.py index 77674be..1ab62d7 100644 --- a/mdsapt/tests/test_repair.py +++ b/mdsapt/tests/test_repair.py @@ -7,7 +7,7 @@ import MDAnalysis as mda -from ..repair import rebuild_resid +from ..repair import StandardChargeStrategy, rebuild_resid resources_dir = Path(__file__).parent / 'testing_resources' @@ -25,7 +25,7 @@ def test_prepare_resids(self) -> None: Tests that a residue is correctly repaired. """ r11: mda.AtomGroup = unv.select_atoms('resid 11') - r11_fixed: mda.AtomGroup = rebuild_resid(11, r11) + r11_fixed: mda.AtomGroup = rebuild_resid(11, r11, StandardChargeStrategy()) assert_array_almost_equal(r11.select_atoms('name CA').positions, r11_fixed.select_atoms('name CA').positions, decimal=3) @@ -44,7 +44,7 @@ def test_prepare_end(self) -> None: Tests last amino acid in chain to ensure it doesn't cap its carboxyl terminus. """ r215 = unv.select_atoms('resid 214') - r215_fixed = rebuild_resid(214, r215) + r215_fixed = rebuild_resid(214, r215, StandardChargeStrategy()) assert len(r215_fixed.select_atoms('name Hc')) == 0 def test_non_amino(self) -> None: @@ -52,5 +52,5 @@ def test_non_amino(self) -> None: Tests non amino acid being passed into rebuild_resid """ r126 = unv.select_atoms('resid 126') - r126_fixed = rebuild_resid(126, r126) + r126_fixed = rebuild_resid(126, r126, StandardChargeStrategy()) assert len(r126_fixed.select_atoms('name Hc')) == 0 diff --git a/mdsapt/utils/ensemble.py b/mdsapt/utils/ensemble.py index 9fe3935..cb8b7aa 100644 --- a/mdsapt/utils/ensemble.py +++ b/mdsapt/utils/ensemble.py @@ -168,10 +168,11 @@ def build_from_dir(cls, ensemble_dir: DirectoryPath, **universe_kwargs) -> 'Ense def build_from_files(cls, topologies: List[Union[str, Path]], **universe_kwargs) -> 'Ensemble': """Constructs an ensemble from a list of files.""" _ens: Dict[str, mda.Universe] = {} - for top in topologies: - name: str = str(top) + for path in topologies: + name: str = str(path) try: - _ens[name] = mda.Universe(name, **universe_kwargs) + unv = mda.Universe(name, **universe_kwargs) + _ens[name] = unv except (mda.exceptions.NoDataError, OSError, ValueError) as err: logger.exception(err) raise err diff --git a/mdsapt/utils/formal_charge.py b/mdsapt/utils/formal_charge.py new file mode 100644 index 0000000..ca6b734 --- /dev/null +++ b/mdsapt/utils/formal_charge.py @@ -0,0 +1,121 @@ +r""" +:mod:`mdsapt.utils.formal_charge` -- Utilities for calculating charge-related properties of an atom +=================================================================================================== + +.. autofunction:: get_fc + +.. autofunction:: get_lone_electrons +""" +from typing import Optional, NamedTuple + + +_ELEMENT_TO_VALENCE = { + 'H': 1, + 'C': 4, + 'N': 5, + 'O': 6, + 'P': 5, + 'S': 6, + + # Halogens + 'F': 7, + 'Cl': 7, + 'Br': 7, + 'I': 7, +} + + +class ElectronInfo(NamedTuple): + """Various electron-related properties of an atom.""" + element: str + valence: int + bonds: int + fc: int + lone: int + radical: int + + +def calculate_electron_info(element: str, bonds: int, fc: Optional[int] = None) -> ElectronInfo: + """ + Calculates various electron-related properties of an atom. + + :param element: the element of the atom + :param bonds: number of bonds attached + :param fc: + Formal charge of the atom, if known. If not provided, it will be automatically calculated. + Our formal charge calculation is optimized for biochemistry simulations, so it might not produce the correct + values for sulfur or phosphorous in more complicated molecules. + """ + + try: + valence = _ELEMENT_TO_VALENCE[element] + except KeyError: + raise ValueError(f"Unsupported element {element}, please manually specify the charge.") + + if fc is None: + fc = _get_fc(element, valence, bonds) + + # Note that we calculate lone back from FC for two reasons: + # - charge is provided by the user + # - we pin the charges of certain element/bond inputs + lone = valence - fc - bonds + + # Calculation of radical electrons. This formula appears to work in most cases. + if bonds == 1 and lone <= 4: + radical = [0, 1, 0, 3, 2][lone] # lookup table + else: + radical = lone % 2 + + return ElectronInfo( + element=element, + valence=valence, + bonds=bonds, + fc=fc, + lone=lone, + radical=radical, + ) + + +def calculate_spin_multiplicity(total_radicals: int) -> int: + """ + Calculates the spin multiplicity of a molecule, given the total number of radical electrons. + + :param total_radicals: The number of radical electrons that exist across the molecule. + """ + total_spin: int = total_radicals // 2 + spin_mult: int = total_spin + 1 + return spin_mult + + +def _get_fc(element: str, valence: int, bonds: int) -> int: + """ + Calculates the formal charge of an atom given its number of bonds. + + This is optimized for biochemistry simulations, so it might not produce the correct + values for sulfur or phosphorous in more complicated molecules. + + :param element: the element of the atom + :param valence: number of valence electrons + :param bonds: number of bonds attached to the atom + """ + if element == 'H': + return 0 + if element == 'P': + if bonds == 5: + return 0 + if bonds == 6: + return -1 + + unpaired = 8 - valence + + if bonds <= unpaired: + lone = valence - bonds + + # Round lone electrons up to next even number + if lone % 2 != 0: + lone += 1 + else: + # Every bond above unpaired donates 2 lone electrons + lone = 8 - 2 * bonds + + return valence - lone - bonds