@@ -519,8 +519,9 @@ def _retrieve_data(self, result_list):
519519 result ["Nm" ] = Nm
520520 result ["total" ] = total
521521
522+
522523 @staticmethod
523- def _calc_ldp_recall_score_item (ldp_arr , ca_arr , rot_mtx , trans ):
524+ def _calc_ldp_recall_score_item (ldp_tree , ca_arr , rot_mtx , trans ):
524525 """
525526 Calculate the recall score of LDP points given a rotation matrix and translation vector
526527 All arguments have to be torch tensors on GPU
@@ -529,22 +530,22 @@ def _calc_ldp_recall_score_item(ldp_arr, ca_arr, rot_mtx, trans):
529530 # rot_mtx: torch tensor of shape (3, 3)
530531 # trans: torch tensor of shape (3, )
531532 """
532- import torch
533-
534- # rotated backbone CA
535- rot_backbone_ca = torch .matmul (ca_arr , rot_mtx ) + trans
533+ # import torch
534+ #
535+ # # rotated backbone CA
536+ # rot_backbone_ca = torch.matmul(ca_arr, rot_mtx) + trans
537+ rot_backbone_ca = np .dot (ca_arr , rot_mtx ) + trans
538+ distances , indices = ldp_tree .query (rot_backbone_ca , k = 1 )
539+ coverage = np .sum (distances < 3.0 ) / len (rot_backbone_ca )
540+ return coverage
536541
537- # calculate all pairwise distances
538- dist_mtx = torch .cdist (rot_backbone_ca , ldp_arr , p = 2 )
539-
540- # get distance from the closest LDP point for each CA atom
541- min_dist = torch .min (dist_mtx , dim = 1 ).values
542-
543- # count the coverage of CA atoms within 3.0 angstrom of LDP points in the total amount of CA atoms
544- return (min_dist < 3.0 ).sum ().item () / len (rot_backbone_ca )
545542
546543 def _calc_ldp_recall (self , results , sort = False , progress_bar = True ):
547544 import torch
545+ from scipy .spatial import KDTree
546+ ldp_atoms = self .ldp_atoms .cpu ().numpy ()
547+ ldp_tree = KDTree (ldp_atoms )
548+ bb_ca = self .backbone_ca .cpu ().numpy ()
548549
549550 if not progress_bar :
550551 iter_results = results # calculate for each rotation
@@ -553,10 +554,8 @@ def _calc_ldp_recall(self, results, sort=False, progress_bar=True):
553554 for result in iter_results :
554555 r = R .from_euler ("xyz" , result ["angle" ], degrees = True )
555556 rot_mtx = (r .as_matrix ()).T
556- rot_mtx = torch .from_numpy (rot_mtx ).to (self .device )
557- # rot_mtx = euler_to_mtx(torch.tensor(result["angle"], device=self.device)).t()
558557 result ["ldp_recall" ] = self ._calc_ldp_recall_score_item (
559- self . ldp_atoms , self . backbone_ca , rot_mtx , torch . from_numpy ( result ["real_trans" ]). to ( self . device )
558+ ldp_tree , bb_ca , rot_mtx , result ["real_trans" ]
560559 )
561560
562561 # sort by LDP recall
0 commit comments