Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions pyinteraph/motion_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import sys

def calculate_dccm(fluct):
""" Generates n*n matrix showing the paired movement of the n residues """
"""
Generates n*n matrix showing the paired movement of the n residues
according to dynamical cross-correlation
"""
n_frames, n_res, _ = fluct.shape

# making the columns into frames and then flattening so that each row is a time series of the motion of each residue
Expand All @@ -19,6 +22,49 @@ def calculate_dccm(fluct):
dccm = cov / denom
return dccm

def calculate_lmi(fluct):
"""
Generates n*n matrix showing the paired movement of the n residues
according to linear mutual information
"""
n_frames, n_res, _ = fluct.shape
epsilon = 1e-6

# computing per residue covariance and log-determinants
Ci_list = []
logdet_list = []

for i in range(n_res):
Xi = fluct[:,i,:]
Ci = Xi.T @ Xi / (n_frames - 1) # calculating 3*3 covariance matrix for each residue
Ci_list.append(Ci)
_, logdet = np.linalg.slogdet(Ci + epsilon * np.eye(3)) # log(det(Ci)) and small regularization term added to avoid det(Ci) = 0
logdet_list.append(logdet)

# pairwise calculations
lmi_norm = np.eye(n_res)

for i in range(n_res):
Xi = fluct[:,i,:]

for j in range(i+1, n_res):
Xj = fluct[:,j,:]

C_cross = Xi.T @ Xj / (n_frames - 1) # 3*3 cross-covariance matrix between pairs of residues

Cij = np.block([ # 6*6 joint covariance matrix where diagonals are covariance matrices of the 2 residues
[Ci_list[i], C_cross ], # and non-diagonals are the cross-covariance matrices
[C_cross.T, Ci_list[j]]
]) + epsilon * np.eye(6)

_, logdet_Cij = np.linalg.slogdet(Cij)

raw = max(0.0, 0.5 * (logdet_list[i] + logdet_list[j] - logdet_Cij))
norm = np.sqrt(1.0 - np.exp(-2.0 * raw/3.0)) # Normalizing

lmi_norm[i,j] = lmi_norm[j,i] = norm
return lmi_norm

def main():
parser = argparse.ArgumentParser(description="Compute correlation metrics from MD trajectory")

Expand Down Expand Up @@ -60,7 +106,7 @@ def main():
)

parser.add_argument(
"-c", "--dccm-csv",
"-c", "--csv",
default = "dccm.csv",
help = "Output file name for CSV format (default: dccm.csv)"
)
Expand Down Expand Up @@ -91,18 +137,17 @@ def main():
mean_pos = coords.mean(axis=0)
fluct = coords - mean_pos


# Calculate correlation matrices
if args.method == "dccm": # lmi to be added
result = calculate_dccm(fluct)
elif args.method == "lmi":
raise NotImplementedError("The calculation of LMI is not implemented yet")
result = calculate_lmi(fluct)

np.savetxt(args.output, result, delimiter=" ") # writing to a .dat file

# Writing a CSV file with residue pairs

with open(args.dccm_csv, "w") as outfile:
with open(args.csv, "w") as outfile:
outfile.write("chain1,residue_number1,residue_name1,atom1,"
"chain2,residue_number2,residue_name2,atom2,correlation\n"
)
Expand Down
Loading