Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ homepage = "https://github.com/nomad-coe/nomad-schema-plugin-simulation-workflow

[tool.uv]
dev-dependencies = [
"nomad-normalizer-plugin-bandstructure",
"nomad-normalizer-plugin-system",
"nomad-normalizer-plugin-bandstructure>=1.0.3",
"nomad-normalizer-plugin-system>=1.0.3",
'mypy>=1.15',
'pytest>= 5.3.0, <8',
'pytest-timeout>=1.4.2',
Expand Down
99 changes: 75 additions & 24 deletions simulationworkflowschema/molecular_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,13 @@ def create_empty_universe(
n_segments = 0

if atom_resindex is None:
LOGGER.warn(
LOGGER.warning(
'Residues specified but no atom_resindex given. '
'All atoms will be placed in first Residue.',
)

if residue_segindex is None:
LOGGER.warn(
LOGGER.warning(
'Segments specified but no segment_resindex given. '
'All residues will be placed in first Segment',
)
Expand Down Expand Up @@ -561,6 +561,11 @@ def _get_molecular_bead_groups(
"""
Creates bead groups based on the molecular types as defined by the MDAnalysis universe.
"""
# Input validation
if universe is None:
LOGGER.warning('Universe required to create beads.')
return {}

if not moltypes:
atoms_moltypes = getattr(universe.atoms, 'moltypes', [])
moltypes = np.unique(atoms_moltypes)
Expand All @@ -581,6 +586,7 @@ def _get_molecular_bead_groups(

def calc_molecular_rdf(
universe: MDAnalysis.Universe,
bead_groups: dict[str, BeadGroup],
n_traj_split: int = 10,
n_prune: int = 1,
interval_indices=None,
Expand All @@ -590,15 +596,32 @@ def calc_molecular_rdf(
Calculates the radial distribution functions between for each unique pair of
molecule types as a function of their center of mass distance.

interval_indices: 2D array specifying the groups of the n_traj_split intervals to be averaged
max_mols: the maximum number of molecules per bead group for calculating the rdf, for efficiency purposes.
Parameters
----------
universe : MDAnalysis.Universe
The MDAnalysis universe object.
bead_groups : dict[str, BeadGroup]
Precomputed bead groups for the universe.
n_traj_split : int
Number of intervals to split trajectory into for averaging.
n_prune : int
Pruning parameter for frames.
interval_indices : list or None
2D array specifying the groups of the n_traj_split intervals to be averaged.
max_mols : int
Maximum number of molecules per bead group for calculating the rdf, for efficiency purposes.
"""
# TODO 5k default for max_mols was set after > 50k was giving problems. Should do further testing to see where the appropriate limit should be set.
if bead_groups is None or not bead_groups:
LOGGER.warning('bead_groups required to calculate RDF.')
return {}

if (
not universe
or not universe.trajectory
or universe.trajectory[0].dimensions is None
):
LOGGER.warning('universe required to calculate RDF.')
return {}

n_frames = universe.trajectory.n_frames
Expand All @@ -623,10 +646,9 @@ def calc_molecular_rdf(
if not interval_indices:
interval_indices = [[i] for i in range(n_traj_split)]

bead_groups = _get_molecular_bead_groups(universe)
if not bead_groups:
return bead_groups
moltypes = [moltype for moltype in bead_groups.keys()]
moltypes = list(bead_groups.keys())
del_list = [
i_moltype
for i_moltype, moltype in enumerate(moltypes)
Expand Down Expand Up @@ -830,7 +852,7 @@ def shifted_correlation_average(
>>> indices, data = shifted_correlation(msd, coords)
"""
if window + skip >= 1:
LOGGER.warn(
LOGGER.warning(
'Invalid parameters for shifted_correlation(), resetting to defaults.',
)
window = 0.5
Expand Down Expand Up @@ -868,14 +890,22 @@ def correlate(start_frame):


def calc_molecular_mean_squared_displacements(
universe: MDAnalysis.Universe, max_mols: int = 5000
universe: MDAnalysis.Universe,
bead_groups: dict[str, BeadGroup],
max_mols: int = 5000,
) -> dict[str, Any]:
"""
Calculates the mean squared displacement for the center of mass of each
molecule type.

max_mols: the maximum number of molecules per bead group for calculating the msd, for efficiency purposes.
50k was tested and is very fast and does not seem to have any memory issues.
Parameters
----------
universe : MDAnalysis.Universe
The MDAnalysis universe object.
bead_groups : dict[str, BeadGroup]
Precomputed bead groups for the universe.
max_mols : int
Maximum number of molecules per bead group for calculating the msd, for efficiency purposes.
"""

def parse_jumps(
Expand Down Expand Up @@ -952,31 +982,35 @@ def mean_squared_displacement(start: np.ndarray, current: np.ndarray):
vec = start - current
return (vec**2).sum(axis=1).mean()

if bead_groups is None or not bead_groups:
LOGGER.warning('bead_groups required to calculate MSD.')
return {}

if (
not universe
or not universe.trajectory
or universe.trajectory[0].dimensions is None
):
LOGGER.warning('universe required to calculate MSD.')
return {}

n_frames = universe.trajectory.n_frames
if n_frames < 50:
LOGGER.warn(
LOGGER.warning(
'At least 50 frames required to calculate molecular'
' mean squared displacements, skipping.',
)
return {}

dt = getattr(universe.trajectory, 'dt')
if dt is None:
LOGGER.warn(
LOGGER.warning(
'Universe is missing time step, cannot calculate molecular'
' mean squared displacements, skipping.',
)
return {}
times = np.arange(n_frames) * dt

bead_groups = _get_molecular_bead_groups(universe)
if bead_groups is {}:
return bead_groups

Expand All @@ -985,7 +1019,7 @@ def mean_squared_displacement(start: np.ndarray, current: np.ndarray):
for i_moltype, moltype in enumerate(moltypes):
if len(bead_groups[moltype].positions) > max_mols:
if max_mols > 50000:
LOGGER.warn(
LOGGER.warning(
'Calculating mean squared displacements for more than 50k molecules.'
' Expect long processing times!',
)
Expand All @@ -1000,18 +1034,18 @@ def mean_squared_displacement(start: np.ndarray, current: np.ndarray):
np.random.choice(molnum_types, size=max_mols)
)
atom_indices_rnd = np.concatenate(
[np.where(molnums == molnum)[0] for molnum in molnum_types_rnd]
[moltype_indices[molnums == molnum] for molnum in molnum_types_rnd]
)
selection = ' '.join([str(i) for i in atom_indices_rnd])
selection = f'index {selection}'
ags_moltype_rnd = universe.select_atoms(selection)
bead_groups[moltype] = BeadGroup(ags_moltype_rnd, compound='fragments')
LOGGER.warn(
LOGGER.warning(
'Maximum number of molecules for calculating the msd has been reached.'
' Will make a random selection for calculation.'
)
except Exception:
LOGGER.warn(
LOGGER.warning(
'Error in selecting random molecules for large group when calculating msd. Skipping this molecule type.'
)
del_list.append(i_moltype)
Expand Down Expand Up @@ -1056,12 +1090,22 @@ def calc_radius_of_gyration(
) -> dict[str, Any]:
"""
Calculates the radius of gyration as a function of time for the atoms 'molecule_atom_indices'.

molecule_atom_indices : np.ndarray
The indices of the atoms corresponding to a single molecule for which the Rg will be calculated.
"""
if molecule_atom_indices is None or len(molecule_atom_indices) == 0:
LOGGER.warning(
'molecule_atom_indices is required to calculate radius of gyration'
)
return {}

if (
not universe
or not universe.trajectory
or universe.trajectory[0].dimensions is None
):
LOGGER.warning('universe is None. Cannot calculate radius of gyration.')
return {}
selection = ' '.join([str(i) for i in molecule_atom_indices])
selection = f'index {selection}'
Expand Down Expand Up @@ -1095,7 +1139,13 @@ def calc_molecular_radius_of_gyration(
"""
Calculates the radius of gyration as a function of time for each polymer in the system.
"""
if not system_topology:
if universe is None:
LOGGER.warning('universe required to calculate molecular radius of gyration')
return []
if system_topology is None or not system_topology:
LOGGER.warning(
'system_topology require to calculate molecular radius of gyration.'
)
return []

rg_results = []
Expand Down Expand Up @@ -2536,10 +2586,6 @@ class MolecularDynamicsResults(ThermodynamicsResults):
sub_section=CorrelationFunction.m_def, repeats=True
)

radial_distribution_functions = SubSection(
Comment thread
JFRudzinski marked this conversation as resolved.
sub_section=RadialDistributionFunction.m_def, repeats=True
)

radius_of_gyration = SubSection(sub_section=RadiusOfGyration, repeats=True)

mean_squared_displacements = SubSection(
Expand All @@ -2564,6 +2610,8 @@ def normalize(self, archive, logger):
if universe is None:
return

bead_groups = _get_molecular_bead_groups(universe)

# calculate molecular radial distribution functions
if not self.radial_distribution_functions:
n_traj_split = (
Expand All @@ -2586,6 +2634,7 @@ def normalize(self, archive, logger):
n_prune = int(universe.trajectory.n_frames / len(archive.run[-1].system))
rdf_results = calc_molecular_rdf(
universe,
bead_groups,
n_traj_split=n_traj_split,
n_prune=n_prune,
interval_indices=interval_indices,
Expand All @@ -2598,7 +2647,9 @@ def normalize(self, archive, logger):

# calculate the molecular mean squared displacements
if not self.mean_squared_displacements:
msd_results = calc_molecular_mean_squared_displacements(universe)
msd_results = calc_molecular_mean_squared_displacements(
universe, bead_groups
)
if msd_results:
sec_msds = MeanSquaredDisplacement()
sec_msds._msd_results = msd_results
Expand Down Expand Up @@ -2627,7 +2678,7 @@ def normalize(self, archive, logger):
for rg in rg_results:
n_frames = rg.get('n_frames')
if len(sec_systems) != n_frames:
self.logger.warning(
logger.warning(
'Mismatch in length of system references in calculation and calculated Rg values.'
'Will not store Rg values under calculation section'
)
Expand Down