-
Notifications
You must be signed in to change notification settings - Fork 0
Correlation Analysis #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
import logging
import sys

# Root logging configuration for the package: INFO and above go to stdout
# with a timestamped "[date time] LEVEL | message" layout.
_LOG_CONFIG = {
    "stream": sys.stdout,
    "level": logging.INFO,
    "format": "[%(asctime)s] %(levelname)s | %(message)s",
    "datefmt": "%Y/%m/%d %H:%M:%S",
}

logging.basicConfig(**_LOG_CONFIG)

# Shared logger used by modules that do `from pyinteraph.core import logger`.
logger = logging.getLogger(__name__)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| import argparse | ||
| import os | ||
|
|
||
|
|
||
class ArgumentParserFileExtensionValidation(argparse.FileType):
    """Validate that a CLI-supplied file name carries an allowed extension.

    ``valid_extensions`` may be a comma-separated string such as
    ``".pdb, .gro"`` (the form the CLI uses) or any iterable of extension
    strings; leading dots, surrounding whitespace and case are ignored.
    """

    # Shared parser used only to emit argparse-style errors:
    # parser.error() prints the message and exits with status 2.
    parser = argparse.ArgumentParser()

    def __init__(self, valid_extensions, file_name):
        self.valid_extensions = valid_extensions
        self.file_name = file_name

    def _allowed_extensions(self):
        """Normalize valid_extensions into a set like {"pdb", "gro"}."""
        if isinstance(self.valid_extensions, str):
            raw = self.valid_extensions.split(",")
        else:
            raw = self.valid_extensions
        return {ext.strip().lstrip(".").lower() for ext in raw}

    def validate_file_extension(self):
        """Return file_name if its extension is allowed, else exit via parser.error.

        Bug fix: the original tested substring containment against the raw
        comma-separated string, so e.g. "file.db" passed a ".pdb" list and a
        file with no extension ("" is a substring of anything) always passed.
        """
        given_extension = os.path.splitext(self.file_name)[1][1:].strip().lower()
        if given_extension not in self._allowed_extensions():
            self.parser.error(f"Invalid file extension. Please provide a {self.valid_extensions} file")
        return self.file_name
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import abc | ||
|
|
||
| import numpy as np | ||
| import itertools | ||
|
|
||
| from pyinteraph.core import logger | ||
| from pyinteraph.correlation.residue_coordinate_matrix import ResidueCoordinateMatrix | ||
|
|
||
|
|
||
class BaseCorrelationAnalyzer(abc.ABC):
    """Template for residue-residue correlation analyses.

    Subclasses implement ``_compute_analyzer``, mapping two per-residue
    coordinate time series to a scalar correlation value; ``run`` assembles
    the full symmetric matrix over all residue pairs.
    """

    def __init__(self, residue_coordinate_matrix: "ResidueCoordinateMatrix", threshold: float = 0) -> None:
        """Store the coordinate provider and the filtering threshold.

        Parameters
        ----------
        residue_coordinate_matrix : ResidueCoordinateMatrix
            Provider of the (n_residues, n_frames) coordinate matrix and the
            residue-name header used by ``to_csv``.
        threshold : float
            Entries <= threshold are zeroed in ``filtered_correlation_matrix``;
            0 disables filtering entirely.
        """
        self.threshold = threshold
        self.residue_coordinate_matrix = residue_coordinate_matrix
        self.n_residues = self.residue_coordinate_matrix.n_residues
        self.coordinates_by_residue = self.residue_coordinate_matrix.coordinates_by_residue
        self.residues_with_atom_number = self.residue_coordinate_matrix.residues_with_atom_number
        # Cache for the computed matrix: the original property re-ran the whole
        # O(n^2) pairwise computation on every access (and to_csv goes through
        # the property), so repeated reads doubled the work.
        self._correlation_matrix = None

    def run(self) -> np.ndarray:
        """Compute the symmetric (n_residues, n_residues) correlation matrix."""
        pairs = itertools.combinations_with_replacement(range(self.n_residues), 2)
        correlation_matrix = np.zeros(shape=(self.n_residues, self.n_residues))
        for i, j in pairs:
            value = self._compute_analyzer(
                self.coordinates_by_residue[i, :], self.coordinates_by_residue[j, :]
            )
            # Fill both triangles at once; the measure is symmetric in (i, j).
            correlation_matrix[i, j] = correlation_matrix[j, i] = value
        return correlation_matrix

    @abc.abstractmethod
    def _compute_analyzer(self, residue_i_coords: np.ndarray, residue_j_coords: np.ndarray) -> float:
        """Return the correlation between two residues' coordinate series."""
        raise NotImplementedError()

    @property
    def filtered_correlation_matrix(self):
        """Correlation matrix with entries <= threshold zeroed (computed once, cached).

        NOTE(review): the ``>`` comparison keeps only positive correlations;
        strongly negative values are zeroed too.  An absolute-value or
        low/high-threshold filter may be preferable -- confirm intent.
        """
        if self._correlation_matrix is None:
            self._correlation_matrix = self.run()
        correlation_matrix = self._correlation_matrix
        if self.threshold == 0:
            return correlation_matrix
        return np.where(correlation_matrix > self.threshold, correlation_matrix, 0)

    def to_csv(self, file_name="correlation") -> None:
        """Write the filtered matrix to ``<file_name>.csv`` with a residue-name header."""
        np.savetxt(f"{file_name}.csv", self.filtered_correlation_matrix, comments='', delimiter=',',
                   header=','.join(self.residues_with_atom_number.keys()))
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| import numpy as np | ||
|
|
||
| from pyinteraph.correlation.base_correlation_analyzer import BaseCorrelationAnalyzer | ||
|
|
||
|
|
||
class DCCMAnalyzer(BaseCorrelationAnalyzer):
    """Dynamic cross-correlation map: Pearson correlation per residue pair."""

    def _compute_analyzer(self, residue_i_coords, residue_j_coords):
        # np.corrcoef yields the 2x2 correlation matrix of the two series;
        # either off-diagonal entry is their Pearson coefficient.
        pair_correlation = np.corrcoef(residue_i_coords, residue_j_coords)
        return pair_correlation[0, 1]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| import numpy as np | ||
|
|
||
| from pyinteraph.core import logger | ||
| from pyinteraph.correlation.base_correlation_analyzer import BaseCorrelationAnalyzer | ||
|
|
||
|
|
||
class LMIAnalyzer(BaseCorrelationAnalyzer):
    """Linearized mutual information (generalized correlation) analyzer.

    Estimates the Gaussian mutual information between two residues'
    coordinate series from their Pearson coefficient and maps it to the
    generalized correlation coefficient
    r_MI = sqrt(1 - exp(-2*I/d)) with d = 3 dimensions
    (Lange & Grubmueller, 2006).
    """

    def _compute_analyzer(self, residue_i_coords, residue_j_coords):
        pearson_coefficient = np.corrcoef(residue_i_coords, residue_j_coords)[1][0]
        # Gaussian-estimate mutual information for a 3-dimensional system.
        mutual_info = - 3/2 * np.log(1 - pearson_coefficient**2)
        # Bug fix: the original omitted the square root, which algebraically
        # collapses the whole expression to r^2 (making the MI detour a no-op);
        # the generalized correlation coefficient is [1 - exp(-2 I / d)]^(1/2).
        linear_mutual_info = np.sqrt(1 - (np.exp(-2 * (mutual_info/3))))
        return linear_mutual_info
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from collections import OrderedDict | ||
| import functools | ||
|
|
||
| import numpy as np | ||
| import MDAnalysis as mda | ||
|
|
||
| from pyinteraph.core import logger | ||
|
|
||
|
|
||
class ResidueCoordinateMatrix:
    """Reduce an MD trajectory to one scalar coordinate per residue per frame.

    Loads a reference/trajectory pair with MDAnalysis and exposes a
    (n_residues, n_frames) matrix where each entry is the mean of all
    coordinate components of the residue's selected atoms in that frame.
    """

    def __init__(self, ref: str, traj: str, atoms: str, backbone: bool) -> None:
        """Store file paths and build the MDAnalysis atom-selection string.

        Parameters
        ----------
        ref : str
            Reference/topology file path, handed straight to mda.Universe.
        traj : str
            Trajectory file path.
        atoms : str
            Comma-separated atom names (ignored when backbone is True).
        backbone : bool
            If True, select backbone atoms instead of the named atoms.
        """
        self.ref = ref
        self.traj = traj
        # NOTE(review): these correlation measures are usually computed on CA
        # atoms only; consider a default "name CA" selection or accepting a
        # custom MDAnalysis selection string, akin to hydrogen-bond groups.
        self.selected_atoms = "backbone" if backbone else f"name {atoms.replace(',', ' ')}"

    # functools.cached_property replaces the original @property over
    # @functools.lru_cache(): lru_cache on instance methods keys the shared
    # cache on `self` and keeps every instance alive for the cache's lifetime.
    @functools.cached_property
    def trajectory(self):
        """MDAnalysis Universe, built lazily and cached per instance."""
        return mda.Universe(self.ref, self.traj)

    @functools.cached_property
    def n_frames(self):
        """Number of frames in the trajectory."""
        return self.trajectory.trajectory.n_frames

    @functools.cached_property
    def n_residues(self):
        """Number of residues in the system."""
        return self.trajectory.residues.residues.n_residues

    @functools.cached_property
    def residues_with_atom_number(self):
        """OrderedDict mapping "<resnum><resname>" -> indices of selected atoms."""
        residues = OrderedDict()
        for res in self.trajectory.residues:
            residues[f"{res.resnum}{res.resname}"] = res.atoms.select_atoms(self.selected_atoms).atoms.ix_array
        return residues

    @functools.cached_property
    def _atom_index_groups(self):
        # Materialized once: the original rebuilt list(dict.values()) for every
        # residue in every frame inside the double loop below.
        return list(self.residues_with_atom_number.values())

    @functools.cached_property
    def coordinates_by_residue(self) -> np.ndarray:
        """(n_residues, n_frames) matrix of per-residue scalar positions."""
        traj_by_res = np.zeros(shape=(self.n_residues, self.n_frames))
        for frame in self.trajectory.trajectory:
            for res_num in range(self.n_residues):
                traj_by_res[res_num][frame.frame] = self.get_atomic_positions_by_geometric_center(frame, res_num)
        return traj_by_res

    def get_atomic_positions_by_geometric_center(self, traj, res_num):
        """Mean of the selected atoms' coordinates for residue res_num in one frame.

        NOTE(review): np.mean over the (n_atoms, 3) slice collapses x/y/z into
        a single scalar as well -- confirm this is intended rather than a
        per-axis geometric centre.
        """
        return np.mean(traj.positions[self._atom_index_groups[res_num]])
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| import argparse | ||
| import warnings | ||
|
|
||
| from pyinteraph.core import logger | ||
| from pyinteraph.correlation.dccm import DCCMAnalyzer | ||
| from pyinteraph.correlation.lmi import LMIAnalyzer | ||
| from pyinteraph.correlation.residue_coordinate_matrix import ResidueCoordinateMatrix | ||
| from pyinteraph.core.validate_parser_file_extension import ArgumentParserFileExtensionValidation | ||
|
|
||
# Silence all warnings (MDAnalysis emits noisy deprecation chatter) so the CLI
# output stays readable.  NOTE(review): this blanket filter also hides warnings
# from our own code -- consider scoping by category or module instead.
warnings.filterwarnings("ignore")
|
|
||
def run_dccm(args, backbone):
    """Build the residue coordinate matrix from the CLI args and write the
    DCCM correlation matrix to dccm.csv."""
    coordinate_matrix = ResidueCoordinateMatrix(args.ref, args.traj, args.atoms, backbone)
    analyzer = DCCMAnalyzer(coordinate_matrix, threshold=args.threshold)
    analyzer.to_csv(file_name="dccm")
|
|
||
|
|
||
def run_lmi(args, backbone):
    """Build the residue coordinate matrix from the CLI args and write the
    LMI correlation matrix to lmi.csv."""
    coordinate_matrix = ResidueCoordinateMatrix(args.ref, args.traj, args.atoms, backbone)
    analyzer = LMIAnalyzer(coordinate_matrix, threshold=args.threshold)
    analyzer.to_csv(file_name="lmi")
|
|
||
|
|
||
def main():
    """CLI entry point: parse arguments and dispatch to DCCM or LMI analysis."""
    parser = argparse.ArgumentParser(description="Script to run correlation analysis using DCCM or LMI")
    subparsers = parser.add_subparsers()

    # Arguments shared by both subcommands, attached via parents=[...].
    common_args = argparse.ArgumentParser(add_help=False)
    common_args.add_argument('--atoms', type=str, required=False,
                             help="Comma-separated atom names used to reduce each residue")
    common_args.add_argument('--backbone', action='store_true', required=False,
                             help="Use backbone atoms instead of an explicit atom list")
    common_args.add_argument("--ref", help="Reference file",
                             type=lambda file_name: ArgumentParserFileExtensionValidation(
                                 (".pdb, .gro, .psf, .top, .crd"), file_name).validate_file_extension(),
                             required=True)
    common_args.add_argument("--traj", help="a trajectory file",
                             type=lambda file_name: ArgumentParserFileExtensionValidation(
                                 (".trj, .pdb, .xtc, .dcd"), file_name).validate_file_extension(),
                             required=True)
    common_args.add_argument("--threshold", type=float, default=0, help="Threshold for the correlation analysis")

    # NOTE(review): before computing DCCM/LMI each frame should ideally be
    # superimposed onto a reference (usually the topology) to remove global
    # roto-translations, with a configurable fit selection (default: CA atoms).

    dccm_parser = subparsers.add_parser("dccm", help="Run DCCM correlation analysis", parents=[common_args])
    dccm_parser.set_defaults(func=run_dccm)

    lmi_parser = subparsers.add_parser("lmi", help="Run LMI correlation analysis", parents=[common_args])
    lmi_parser.set_defaults(func=run_lmi)

    args = parser.parse_args()

    # No subcommand given: show usage instead of crashing on args.func below.
    if not hasattr(args, 'func'):
        parser.print_help()
        return

    # Fall back to the backbone selection when no atom list was supplied or
    # --backbone was explicitly requested alongside one.
    backbone = False
    if (args.atoms and args.backbone) or not args.atoms:
        # Bug fix: the original message was an f-string with no placeholders
        # and always said "dccm" even when running the lmi subcommand.
        logger.warning("Backbone atoms will be utilized to compute the correlation analysis.")
        backbone = True
    args.func(args, backbone)


if __name__ == "__main__":
    main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see the reason to have this function — couldn't this step just be performed in `self.run`?

Also, it would be nice if we could have a low and a high threshold, and keep values that are > high or < low. This is because we might want correlations that are larger in absolute value than a certain threshold (and keep the highly negative correlations).