diff --git a/LossLab/__init__.py b/LossLab/__init__.py index b0b9b54..a72c9a3 100644 --- a/LossLab/__init__.py +++ b/LossLab/__init__.py @@ -6,8 +6,18 @@ __version__ = "0.1.0" + +def __getattr__(name): + if name == "CryoEMLLGLoss": + from LossLab.losses.cryoLLGI import CryoEMLLGLoss + + return CryoEMLLGLoss + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ + "CryoEMLLGLoss", "RealSpaceLoss", - "RefinementEngine", "RefinementConfig", + "RefinementEngine", ] diff --git a/LossLab/cryo/__init__.py b/LossLab/cryo/__init__.py new file mode 100644 index 0000000..86f8b1d --- /dev/null +++ b/LossLab/cryo/__init__.py @@ -0,0 +1,21 @@ +"""Cryo-EM specific helpers (re-exported from canonical locations).""" + + +def __getattr__(name): + if name == "CryoEMLLGLoss": + from LossLab.losses.cryoLLGI import CryoEMLLGLoss + + return CryoEMLLGLoss + if name in ("extract_allatoms", "get_res_names", "position_alignment"): + from LossLab.utils import alignment + + return getattr(alignment, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "CryoEMLLGLoss", + "extract_allatoms", + "get_res_names", + "position_alignment", +] diff --git a/LossLab/cryo/alignment.py b/LossLab/cryo/alignment.py new file mode 100644 index 0000000..c4f1a05 --- /dev/null +++ b/LossLab/cryo/alignment.py @@ -0,0 +1,258 @@ +"""OF3/Rocket atom extraction and alignment helpers for cryo losses.""" + +from __future__ import annotations + +import logging +import numpy as np +import torch + +try: + from openfold3.core.data.resources.token_atom_constants import ( + TOKEN_NAME_TO_ATOM_NAMES, + TOKEN_TYPES_WITH_GAP, + ) +except ModuleNotFoundError: + from openfold3.core.np.token_atom_constants import ( + TOKEN_NAME_TO_ATOM_NAMES, + TOKEN_TYPES_WITH_GAP, + ) +from rocket import coordinates as rk_coordinates +from rocket import utils as rk_utils + + +logger = logging.getLogger(__name__) + +def get_res_names(aatype: np.ndarray) -> np.ndarray: + 
"""Get 3-letter residue names from unified token set.""" + token_names = np.array(TOKEN_TYPES_WITH_GAP) # ["ALA", ..., "G", ..., "DA", ..., "GAP"] + return token_names[np.clip(aatype, 0, len(token_names) - 1)] + +def _restype_to_token_idx(restype, vocab_size: int) -> np.ndarray: + """ + Accepts restype in any of: + - [n_tok] (already indices) + - [n_tok, vocab] (one-hot/logits) + - [1, n_tok, vocab] (batched one-hot/logits) <-- your case + Returns: [n_tok] int64 token indices clipped into vocab. + """ + rt = rk_utils.assert_numpy(restype) + + # peel leading singleton dims (batch, etc.) + while rt.ndim > 2 and rt.shape[0] == 1: + rt = rt[0] + + if rt.ndim == 1: + idx = rt.astype(np.int64) + elif rt.ndim == 2 and rt.shape[-1] == vocab_size: + idx = rt.argmax(-1).astype(np.int64) + else: + raise ValueError(f"Unexpected restype shape after squeeze: {rt.shape}") + + return np.clip(idx.reshape(-1), 0, vocab_size - 1) + + +def get_res_names(restype: np.ndarray) -> np.ndarray: + """Get 3-letter residue/token names from unified token set.""" + token_names = np.asarray(TOKEN_TYPES_WITH_GAP, dtype=object) + token_idx = _restype_to_token_idx(restype, len(token_names)) + return token_names[token_idx] + +def extract_allatoms(outputs, feats, cra_name_sfc: list): + token_names = np.asarray(TOKEN_TYPES_WITH_GAP, dtype=object) + atom_names_dict = TOKEN_NAME_TO_ATOM_NAMES + + # restype: (1, n_tok, vocab) in your case + token_idx = _restype_to_token_idx(feats["restype"], len(token_names)) + res_names = token_names[token_idx] + n_res = int(res_names.shape[0]) + chain_resid = np.asarray([f"A-{i}-" for i in range(n_res)], dtype=object) + + # atom-flat predicted coords WITH GRAD + atom_pos_pred = outputs["atom_positions_predicted"][0].squeeze(0) # [n_atom, 3] + if atom_pos_pred.ndim != 2 or atom_pos_pred.shape[-1] != 3: + raise ValueError(f"Unexpected atom_positions_predicted shape: {tuple(atom_pos_pred.shape)}") + n_atom = int(atom_pos_pred.shape[0]) + + # atom-flat mask + atom_mask = 
rk_utils.assert_numpy(feats["atom_mask"][0]).reshape(-1) # [n_atom] + if atom_mask.shape[0] != n_atom: + raise ValueError(f"atom_mask {atom_mask.shape[0]} != n_atom {n_atom}") + + # per-atom pLDDT (no grad needed) + pl_logits = outputs["plddt_logits"][0].squeeze(0).to(torch.float32) # [n_atom, n_bins] + probs = torch.softmax(pl_logits, dim=-1) + bin_centers = torch.linspace(0.0, 1.0, steps=probs.shape[-1], device=probs.device) + plddt_atom = ((probs * bin_centers).sum(dim=-1) * 100.0).detach().cpu().numpy() # [n_atom] + + cra_names = [] + atom_positions = [] + plddts = [] + + global_atom_index = 0 + for i in range(n_res): + rn = res_names[i] + resname = rn.decode() if isinstance(rn, (bytes, np.bytes_)) else str(rn) + + atoms = atom_names_dict.get(resname, []) + for aname in atoms: + if global_atom_index >= n_atom: + raise IndexError("global_atom_index ran past atom_positions_predicted") + + if atom_mask[global_atom_index] == 0: + global_atom_index += 1 + continue + + cra_names.append(f"{chain_resid[i]}{resname}-{aname}") + # IMPORTANT: keep grad! 
+ atom_positions.append(atom_pos_pred[global_atom_index].to(torch.float32)) + plddts.append(float(plddt_atom[global_atom_index])) + global_atom_index += 1 + + if not atom_positions: + raise RuntimeError("No atoms extracted (empty atom_positions).") + + positions_atom = torch.stack(atom_positions, dim=0) # [n_kept, 3], REQUIRES_GRAD if input does + plddt_atom_t = torch.tensor(plddts, dtype=torch.float32, device=positions_atom.device) + + # reorder to SFC topology + idx_map = {c: k for k, c in enumerate(cra_names)} + missing = [c for c in cra_name_sfc if c not in idx_map] + if missing: + raise AssertionError(f"Topology mismatch; missing {len(missing)} CRA (showing 10): {missing[:10]}") + + reorder_index = [idx_map[c] for c in cra_name_sfc] + + return positions_atom[reorder_index], plddt_atom_t[reorder_index] + +from LossLab.utils import geometry as geom +import numpy as np +import torch + +def _anchor_idx_from_cra(cra_name): + # prefer protein backbone if present; else RNA backbone; else all atoms + atoms = np.array([str(c).split("-")[-1].replace("*", "'") for c in cra_name], dtype=object) + + prot = np.isin(atoms, ["N", "CA", "C"]) + if prot.any(): + return np.where(prot)[0] + + rna = np.isin(atoms, ["P","OP1","OP2","O5'","C5'","C4'","O4'","C3'","O3'","C2'","C1'"]) + if rna.any(): + return np.where(rna)[0] + + return np.arange(len(cra_name), dtype=np.int64) + + +def position_alignment( + rollout_output, + batch, + best_pos, + exclude_res, + cra_name, + domain_segs=None, + reference_bfactor=None, # optional per-atom B-factors (Tensor) +): + """ + Minimal alignment: + - coords: Tensor [..., N, 3] or [N, 3] + - best_pos: reference coords Tensor [..., N, 3] or [N, 3] + - exclude_res: indices to mask out (list/array/Tensor), optional + - plddt OR reference_bfactor provide weights; if neither, uniform weights + Returns: + aligned_xyz (torch.float32 [N,3]), + plddts_res (np.float32 [N]), + pseudo_Bs (torch.float32 [N]) + """ + coords, plddts = 
extract_allatoms(rollout_output, batch, cra_name) + pseudo_Bs = rk_utils.plddt2pseudoB_pt(plddts) + if reference_bfactor is None: + pseudoB_np = rk_utils.assert_numpy(pseudo_Bs) + cutoff1 = np.quantile(pseudoB_np, 0.3) + cutoff2 = cutoff1 * 1.5 + weights = rk_utils.weighting(pseudoB_np, cutoff1, cutoff2) + else: + assert reference_bfactor.shape == pseudo_Bs.shape, ( + "Reference bfactor should have same shape as model bfactor!" + ) + reference_bfactor_np = rk_utils.assert_numpy(reference_bfactor) + cutoff1 = np.quantile(reference_bfactor_np, 0.3) + cutoff2 = cutoff1 * 1.5 + weights = rk_utils.weighting(reference_bfactor_np, cutoff1, cutoff2) + + # --- alignment --- + aligned_xyz = rk_coordinates.iterative_kabsch_alignment( + coords, best_pos, cra_name, + weights=weights, + exclude_res=exclude_res, + domain_segs=domain_segs, + ) + mask = cra_exclude_mask(cra_name, exclude_res=exclude_res) + + rmsd_pre = weighted_rmsd(coords, best_pos, w=weights, mask=mask) + rmsd_post = weighted_rmsd(aligned_xyz, best_pos, w=weights, mask=mask) + + logger.debug("RMSD pre-align: %.3f post-align: %.3f", rmsd_pre, rmsd_post) + + if rmsd_post > rmsd_pre + 1e-3: + logger.warning("RMSD increased after alignment — check ordering / weights / mask.") + + return aligned_xyz, plddts, pseudo_Bs.detach() + +#debug +def _to_numpy(x): + if isinstance(x, np.ndarray): + return x + if torch.is_tensor(x): + return x.detach().cpu().numpy() + return np.asarray(x) + +def weighted_rmsd(xyz_a, xyz_b, w=None, mask=None): + """ + xyz_*: [N,3] + w: [N] (optional) weights per atom (can be unnormalized) + mask: [N] bool (optional) atoms to include + """ + a = _to_numpy(xyz_a).astype(np.float64, copy=False) + b = _to_numpy(xyz_b).astype(np.float64, copy=False) + assert a.shape == b.shape and a.ndim == 2 and a.shape[1] == 3, (a.shape, b.shape) + + if mask is None: + mask = np.isfinite(a).all(-1) & np.isfinite(b).all(-1) + else: + mask = mask & np.isfinite(a).all(-1) & np.isfinite(b).all(-1) + + if w is None: + 
diff2 = ((a[mask] - b[mask]) ** 2).sum(-1) # [M] + return float(np.sqrt(diff2.mean())) + else: + ww = _to_numpy(w).reshape(-1).astype(np.float64, copy=False) + assert ww.shape[0] == a.shape[0], (ww.shape, a.shape) + ww = ww[mask] + ww_sum = ww.sum() + if ww_sum <= 0: + raise ValueError(f"Non-positive weight sum after masking: {ww_sum}") + ww = ww / ww_sum + diff2 = ((a[mask] - b[mask]) ** 2).sum(-1) + return float(np.sqrt((ww * diff2).sum())) + +def cra_exclude_mask(cra_name, exclude_res=None): + """ + cra_name entries look like 'A-12-ALA-CA' (from your construction). + exclude_res: iterable of residue indices (ints) to exclude. + Returns: [N] bool mask (True = keep) + """ + if not exclude_res: + return None + ex = set(int(r) for r in exclude_res) + keep = np.ones(len(cra_name), dtype=bool) + for i, cra in enumerate(cra_name): + # 'A-12-ALA-CA' -> parts[1] == '12' + parts = str(cra).split("-") + if len(parts) >= 2: + try: + resid = int(parts[1]) + if resid in ex: + keep[i] = False + except ValueError: + pass + return keep \ No newline at end of file diff --git a/LossLab/losses/__init__.py b/LossLab/losses/__init__.py index ceac9b1..48621c4 100644 --- a/LossLab/losses/__init__.py +++ b/LossLab/losses/__init__.py @@ -3,5 +3,24 @@ from LossLab.losses.base import BaseLoss from LossLab.losses.mse import MSECoordinatesLoss from LossLab.losses.realspace import RealSpaceLoss +from LossLab.losses.saxs import DebyeLoss, DebyeRawLoss, debye_intensity, load_saxs_data -__all__ = ["RealSpaceLoss", "MSECoordinatesLoss", "BaseLoss"] + +def __getattr__(name): + if name == "CryoEMLLGLoss": + from LossLab.losses.cryoLLGI import CryoEMLLGLoss + + return CryoEMLLGLoss + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "BaseLoss", + "CryoEMLLGLoss", + "DebyeLoss", + "DebyeRawLoss", + "MSECoordinatesLoss", + "RealSpaceLoss", + "debye_intensity", + "load_saxs_data", +] diff --git a/LossLab/losses/cryoLLGI.py b/LossLab/losses/cryoLLGI.py new file 
mode 100644 index 0000000..bf9e4a1 --- /dev/null +++ b/LossLab/losses/cryoLLGI.py @@ -0,0 +1,370 @@ +"""Cryo-EM LLG loss — mirrors ROCKET ``refinement_cryoem.py`` exactly. + +Pipeline per step: + 1. Extract all-atom positions from OF3 rollout (with grad). + 2. Kabsch-align to reference pose (no_grad), apply STE for gradient flow. + 3. Set pseudo-B-factors on both SFCs from pLDDT. + 4. Compute per-atom RSCC → update B-factors on both SFCs. + 5. Rigid-body refinement (RBR) via quaternion LBFGS on the RBR LLGloss. + 6. Score with ``-LLG`` (main LLGloss) + B-factor-weighted L2 penalty. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import gemmi +import numpy as np +import torch +from rocket import coordinates as rk_coordinates +from rocket import utils as rk_utils +from rocket.cryo import structurefactors as cryo_sf +from rocket.cryo import targets as cryo_targets +from SFC_Torch import PDBParser + +from LossLab.losses.base import BaseLoss +from LossLab.utils.alignment import ( + cra_exclude_mask, + extract_allatoms, + position_alignment, + weighted_rmsd, +) + +logger = logging.getLogger(__name__) + +# Suppress unused-import warning: np is used implicitly via rk_utils +_np = np + + +class CryoEMLLGLoss(BaseLoss): + """Cryo-EM negative-LLG loss matching ROCKET's cryo refinement loop. + + Two ``SFcalculator`` / ``LLGloss`` pairs are maintained — one for the + main LLG scoring, one for rigid-body refinement — exactly as in + ``rocket.refinement_cryoem.run_cryoem_refinement``. 
+ """ + + def __init__( + self, + input_cif: str | Path, + input_mtz: str | Path, + device: torch.device | str = "cuda:0", + n_bins: int = 20, + e_label: str = "Emean", + phie_label: str = "PHIEmean", + l2_weight: float = 1e-10, + num_batch: int = 1, + sub_ratio: float = 1.0, + sfc_scale: bool = False, + rbr_lbfgs: bool = True, + rbr_lbfgs_lr: float = 150.0, + rbr_verbose: bool = False, + domain_segs: list[int] | None = None, + exclude_res: list[int] | None = None, + fixed_b_factor: float | None = None, + ) -> None: + super().__init__(device) + self.input_cif = str(input_cif) + self.input_mtz = str(input_mtz) + self.n_bins = int(n_bins) + self.e_label = e_label + self.phie_label = phie_label + self.l2_weight = float(l2_weight) + self.num_batch = int(num_batch) + self.sub_ratio = float(sub_ratio) + self.sfc_scale = bool(sfc_scale) + self.rbr_lbfgs = bool(rbr_lbfgs) + self.rbr_lbfgs_lr = float(rbr_lbfgs_lr) + self.rbr_verbose = bool(rbr_verbose) + self.domain_segs = domain_segs + self.exclude_res = exclude_res + self.fixed_b_factor = fixed_b_factor + + # --- main SFC + LLGloss (for final scoring) --- + self.mtz = rk_utils.load_mtz(self.input_mtz) + + # Strip hydrogens: OF3 only predicts heavy atoms, so the SFC + # topology must exclude H to match the model output. 
+ pdb_model = self._load_pdb_no_h(self.input_cif) + + self.sfc = cryo_sf.initial_cryoSFC( + pdb_model, + self.mtz, + self.e_label, + self.phie_label, + self.device, + self.n_bins, + ) + if hasattr(self.sfc, "scales") and self.sfc.scales is not None: + self.sfc.scales = self.sfc.scales.detach() + self.target = cryo_targets.LLGloss(self.sfc, self.input_mtz) + + # --- RBR SFC + LLGloss (for rigid-body refinement) --- + pdb_model_rbr = self._load_pdb_no_h(self.input_cif) + self.sfc_rbr = cryo_sf.initial_cryoSFC( + pdb_model_rbr, + self.mtz, + self.e_label, + self.phie_label, + self.device, + self.n_bins, + ) + if hasattr(self.sfc_rbr, "scales") and self.sfc_rbr.scales is not None: + self.sfc_rbr.scales = self.sfc_rbr.scales.detach() + self.target_rbr = cryo_targets.LLGloss(self.sfc_rbr, self.input_mtz) + + # Reference pose and B-factor weights + self.reference_pos = self.sfc.atom_pos_orth.detach().clone() + self.best_pos = self.reference_pos.clone() + self._init_bfactors() + + # Override with fixed B-factor if requested + if self.fixed_b_factor is not None: + n_atoms = int(self.sfc.atom_pos_orth.shape[0]) + fixed_bs = torch.full( + (n_atoms,), + self.fixed_b_factor, + dtype=torch.float32, + device=self.device, + ) + self.sfc.atom_b_iso = fixed_bs.clone() + self.sfc_rbr.atom_b_iso = fixed_bs.clone() + self.reference_b_iso = fixed_bs.clone() + self.init_pos_bfactor = fixed_bs.clone() + logger.info( + "Using fixed B-factor = %.1f for all atoms", + self.fixed_b_factor, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _load_pdb_no_h(cif_path: str) -> PDBParser: + """Load CIF/PDB, strip hydrogens and ligands/waters.""" + st = gemmi.read_structure(str(cif_path)) + st.remove_hydrogens() + st.remove_ligands_and_waters() + return PDBParser(st) + + def _init_bfactors(self) -> None: + """Compute initial RSCC-derived B-factors and alignment 
weights.""" + self.gridsize = self.mtz.get_reciprocal_grid_size(sample_rate=3.0) + self.rg = torch.tensor( + rk_utils.g_function_np(2 * self.sfc.dmin, 1 / self.sfc.dHKL), + device=self.device, + dtype=torch.float32, + ) + self.uc_volume = self.sfc.unit_cell.volume + + dobs_values = self.mtz["Dobs"].to_numpy() + sf_np = self.mtz.to_structurefactor( + sf_key=self.e_label, + phase_key=self.phie_label, + ).to_numpy() + self.rscc_reference_fmap = torch.tensor( + dobs_values * sf_np, + device=self.device, + dtype=torch.complex64, + ) + + fprotein = self.sfc.calc_fprotein(Return=True).to(torch.complex64) + cc_full = rk_utils.get_rscc_from_Fmap( + fprotein.detach(), + self.rscc_reference_fmap, + self.sfc.HKL_array, + self.gridsize, + self.rg, + self.uc_volume, + ) + atom_cc_full = rk_utils.interpolate_grid_points( + cc_full, + self.sfc.atom_pos_frac.detach().to(torch.float32).cpu().numpy(), + ) + rscc_b = torch.tensor( + rk_utils.get_b_from_CC(atom_cc_full, self.sfc.dmin), + dtype=torch.float32, + device=self.device, + ) + self.sfc.atom_b_iso = rscc_b.detach().clone() + self.sfc_rbr.atom_b_iso = rscc_b.detach().clone() + self.reference_b_iso = rscc_b.detach().clone() + self.init_pos_bfactor = rscc_b.detach().clone() + + cutoff1 = torch.quantile(self.reference_b_iso, 0.30) + cutoff2 = cutoff1 * 1.5 + bw = rk_utils.weighting_torch(self.reference_b_iso, cutoff1, cutoff2) + self.bfactor_weights = bw / bw.sum().clamp_min(1e-8) + + def _update_bfactors_rscc(self, aligned_xyz: torch.Tensor) -> None: + """Recompute RSCC B-factors — updates both SFCs.""" + self.sfc_rbr.atom_pos_orth = aligned_xyz.detach().clone() + fprotein = self.sfc_rbr.calc_fprotein(Return=True).to(torch.complex64) + ccmap = rk_utils.get_rscc_from_Fmap( + fprotein.detach(), + self.rscc_reference_fmap, + self.sfc_rbr.HKL_array, + self.gridsize, + self.rg, + self.uc_volume, + ) + atom_cc = rk_utils.interpolate_grid_points( + ccmap, + self.sfc_rbr.atom_pos_frac.detach().to(torch.float32).cpu().numpy(), + ) + rscc_b 
= torch.tensor( + rk_utils.get_b_from_CC(atom_cc, self.sfc_rbr.dmin), + dtype=torch.float32, + device=self.device, + ) + self.sfc_rbr.atom_b_iso = rscc_b.detach().clone() + self.sfc.atom_b_iso = rscc_b.detach().clone() + + # ------------------------------------------------------------------ + # Main entry points + # ------------------------------------------------------------------ + + def compute_from_rollout( + self, + rollout_output: dict, + batch: dict, + exclude_res=None, + domain_segs=None, + return_metadata: bool = False, + return_coordinates: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]: + """Compute loss from a diffusion rollout output dict. + + Follows ROCKET ``refinement_cryoem.py`` step-for-step: + align -> pseudo-B -> RSCC B-update -> RBR -> -LLG + L2. + """ + _exclude = exclude_res or self.exclude_res + _domain = domain_segs or self.domain_segs + + # 1. Extract atom positions in SFC order (keeps grad) + x0, _ = extract_allatoms(rollout_output, batch, self.sfc.cra_name) + x0 = x0.to(self.device) + + # 2. Kabsch align to best_pos + aligned_xyz, plddts_res, pseudo_bs = position_alignment( + rollout_output=rollout_output, + batch=batch, + best_pos=self.best_pos, + exclude_res=_exclude, + cra_name=self.sfc.cra_name, + domain_segs=_domain, + reference_bfactor=self.init_pos_bfactor, + ) + aligned_xyz = aligned_xyz.to(self.device) + pseudo_bs = pseudo_bs.to(self.device) + + # 3. Set B-factors + if self.fixed_b_factor is not None: + n_atoms = int(self.sfc.atom_pos_orth.shape[0]) + fixed_bs = torch.full( + (n_atoms,), + self.fixed_b_factor, + dtype=torch.float32, + device=self.device, + ) + self.sfc.atom_b_iso = fixed_bs.clone() + self.sfc_rbr.atom_b_iso = fixed_bs.clone() + else: + self.sfc.atom_b_iso = pseudo_bs.detach().clone() + self.sfc_rbr.atom_b_iso = pseudo_bs.detach().clone() + self._update_bfactors_rscc(aligned_xyz) + + # 4. 
Optional SFC scale on RBR SFC + if self.sfc_scale: + self.sfc_rbr.calc_fprotein() + self.sfc_rbr.get_scales_adam( + lr=0.01, + n_steps=10, + sub_ratio=0.7, + initialize=False, + ) + + # 5. Rigid-body refinement (RBR) via quaternion LBFGS + optimized_xyz, _ = rk_coordinates.rigidbody_refine_quat( + aligned_xyz, + self.target_rbr, + self.sfc_rbr.cra_name, + domain_segs=_domain, + lbfgs=self.rbr_lbfgs, + lbfgs_lr=self.rbr_lbfgs_lr, + verbose=self.rbr_verbose, + ) + + # 6. Update main SFC position + self.sfc.atom_pos_orth = optimized_xyz.detach().clone() + + # 7. Score: -LLG + l_llg = -self.target( + optimized_xyz.to(torch.float32), + bin_labels=None, + num_batch=self.num_batch, + sub_ratio=self.sub_ratio, + update_scales=self.sfc_scale, + ) + + # 8. B-factor weighted L2 penalty + l2_loss = torch.sum( + self.bfactor_weights.unsqueeze(-1) + * (optimized_xyz - self.reference_pos) ** 2 + ) + loss = l_llg + self.l2_weight * l2_loss + + if not return_metadata: + return loss + + # Metadata + try: + rmsd_mask = cra_exclude_mask(self.sfc.cra_name, exclude_res=_exclude) + w = self.bfactor_weights.detach() + ref = self.reference_pos.detach() + rmsd_raw = float(weighted_rmsd(x0.detach(), ref, w=w, mask=rmsd_mask)) + rmsd_aligned = float( + weighted_rmsd(aligned_xyz.detach(), ref, w=w, mask=rmsd_mask) + ) + rmsd_rbr = float( + weighted_rmsd(optimized_xyz.detach(), ref, w=w, mask=rmsd_mask) + ) + except Exception: + rmsd_raw = rmsd_aligned = rmsd_rbr = float("nan") + + md: dict[str, Any] = { + "mean_plddt": float(plddts_res.mean()), + "mean_pseudo_b": float(pseudo_bs.mean()), + "l_llg": float(l_llg.detach()), + "l2_loss": float(l2_loss.detach()), + "loss": float(loss.detach()), + "rmsd_raw_to_ref": rmsd_raw, + "rmsd_aligned_to_ref": rmsd_aligned, + "rmsd_rbr_to_ref": rmsd_rbr, + "optimized_xyz": optimized_xyz.detach().clone(), + } + if return_coordinates: + md["coords_raw"] = x0.detach() + md["coords_aligned"] = aligned_xyz.detach() + md["coords_rbr"] = optimized_xyz.detach() + 
md["cra_name_sfc"] = list(self.sfc.cra_name) + return loss, md + + def update_best_pos(self, new_best_xyz: torch.Tensor) -> None: + """Update Kabsch alignment target to the best coordinates.""" + self.best_pos = new_best_xyz.detach().clone().to(self.device) + + def compute( + self, + coordinates: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, dict]: + """Score raw coordinates directly (no alignment/rollout).""" + return_metadata = kwargs.pop("return_metadata", False) + loss = -self.target(coordinates.to(torch.float32), **kwargs) + if return_metadata: + return loss, {} + return loss diff --git a/LossLab/losses/mse.py b/LossLab/losses/mse.py index e9b8149..5eb59d9 100644 --- a/LossLab/losses/mse.py +++ b/LossLab/losses/mse.py @@ -2,10 +2,16 @@ from __future__ import annotations +import logging +from typing import Any + import numpy as np import torch from LossLab.losses.base import BaseLoss +from LossLab.utils.sequence import compute_common_indices + +logger = logging.getLogger(__name__) class MSECoordinatesLoss(BaseLoss): @@ -42,6 +48,7 @@ def __init__( if reference_pdb is not None: self.reference_cra = list(reference_pdb.cra_name) + self._reference_sequence = getattr(reference_pdb, "sequence", None) self.index_moving: np.ndarray | None = None self.index_reference: np.ndarray | None = None @@ -51,50 +58,14 @@ def __init__( def set_moving_pdb(self, moving_pdb) -> None: if self.reference_cra is None: raise ValueError("reference_pdb is required to set moving_pdb") - self.index_moving, self.index_reference = self._compute_common_indices( - moving_pdb.cra_name, + self.index_moving, self.index_reference = compute_common_indices( + list(moving_pdb.cra_name), self.reference_cra, self.selection, + moving_sequence=getattr(moving_pdb, "sequence", None), + reference_sequence=getattr(self, "_reference_sequence", None), ) - @staticmethod - def _compute_common_indices( - moving_cra: list[str], - reference_cra: list[str], - selection: str, - ) -> 
tuple[np.ndarray, np.ndarray]: - if selection not in {"ALL", "CA", "BB"}: - raise ValueError("selection must be one of: ALL, CA, BB") - - def _keep(name: str) -> bool: - if selection == "ALL": - return True - if selection == "CA": - return name.endswith("-CA") - if selection == "BB": - return ( - name.endswith("-N") or name.endswith("-CA") or name.endswith("-C") - ) - return True - - reference_lookup = { - name: idx for idx, name in enumerate(reference_cra) if _keep(name) - } - index_moving = [] - index_reference = [] - for idx, name in enumerate(moving_cra): - if not _keep(name): - continue - ref_idx = reference_lookup.get(name) - if ref_idx is not None: - index_moving.append(idx) - index_reference.append(ref_idx) - - if not index_moving: - raise ValueError("No overlapping atoms found between moving and reference") - - return np.array(index_moving), np.array(index_reference) - def compute( self, coordinates: torch.Tensor, @@ -135,3 +106,56 @@ def compute( rmse = torch.sqrt(torch.mean(diff**2)).item() return loss, {"rmse": rmse} return loss + + def set_matching_indices( + self, + query_idx: list[int], + ref_idx: list[int], + ) -> None: + """Pre-compute matching atom indices for rollout -> reference mapping. + + Args: + query_idx: Indices into the flat atom array of the OF3 prediction. + ref_idx: Corresponding indices into ``self.reference_coordinates``. + """ + self._query_idx = query_idx + self._ref_idx = ref_idx + + def compute_from_rollout( + self, + rollout_output: dict, + batch: dict, + return_metadata: bool = False, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]: + """Compute MSE loss directly from OF3 diffusion rollout output. + + Indexes into ``atom_positions_predicted`` with pre-stored query/ref + indices (set via :meth:`set_matching_indices`), Kabsch-aligns, and + returns the MSE. 
+ """ + if not hasattr(self, "_query_idx") or self._query_idx is None: + raise ValueError( + "Call set_matching_indices(query_idx, ref_idx) before " + "compute_from_rollout" + ) + + xl_pred = rollout_output["atom_positions_predicted"] # [1, 1, N_atom, 3] + pred_coords = xl_pred[0, 0, self._query_idx] # [N_match, 3] + ref_coords = self.reference_coordinates[self._ref_idx].to(pred_coords.device) + + if self.align: + from LossLab.utils.geometry import kabsch_align + + pred_coords = kabsch_align(pred_coords, ref_coords) + + diff = pred_coords - ref_coords + if self.reduction == "sum": + loss = torch.sum(diff**2) + else: + loss = torch.mean(diff**2) + + if return_metadata: + rmse = torch.sqrt(torch.mean(diff**2)).item() + return loss, {"rmse": rmse} + return loss diff --git a/LossLab/utils/__init__.py b/LossLab/utils/__init__.py index b2727e4..00f353b 100644 --- a/LossLab/utils/__init__.py +++ b/LossLab/utils/__init__.py @@ -10,13 +10,13 @@ from LossLab.utils.map_utils import apply_mask, create_spherical_mask, normalize_map __all__ = [ - "gpu_memory_tracked", - "timed", + "apply_mask", "cached_property", - "validate_shapes", - "kabsch_align", "compute_rmsd", - "normalize_map", - "apply_mask", "create_spherical_mask", + "gpu_memory_tracked", + "kabsch_align", + "normalize_map", + "timed", + "validate_shapes", ] diff --git a/LossLab/utils/alignment.py b/LossLab/utils/alignment.py new file mode 100644 index 0000000..5216808 --- /dev/null +++ b/LossLab/utils/alignment.py @@ -0,0 +1,254 @@ +"""OF3/ROCKET atom extraction and alignment helpers.""" + +from __future__ import annotations + +import logging + +import numpy as np +import torch + +try: + from openfold3.core.data.resources.token_atom_constants import ( + TOKEN_NAME_TO_ATOM_NAMES, + TOKEN_TYPES_WITH_GAP, + ) +except ModuleNotFoundError: + from openfold3.core.np.token_atom_constants import ( + TOKEN_NAME_TO_ATOM_NAMES, + TOKEN_TYPES_WITH_GAP, + ) +from rocket import coordinates as rk_coordinates +from rocket import 
utils as rk_utils + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# OF3 token/atom helpers +# ------------------------------------------------------------------ + + +def _restype_to_token_idx(restype: np.ndarray, vocab_size: int) -> np.ndarray: + """Convert restype array to 1-D int64 token indices.""" + rt = rk_utils.assert_numpy(restype) + + while rt.ndim > 2 and rt.shape[0] == 1: + rt = rt[0] + + if rt.ndim == 1: + idx = rt.astype(np.int64) + elif rt.ndim == 2 and rt.shape[-1] == vocab_size: + idx = rt.argmax(-1).astype(np.int64) + else: + raise ValueError(f"Unexpected restype shape after squeeze: {rt.shape}") + + return np.clip(idx.reshape(-1), 0, vocab_size - 1) + + +def get_res_names(restype: np.ndarray) -> np.ndarray: + """Get 3-letter residue/token names from unified token set.""" + token_names = np.asarray(TOKEN_TYPES_WITH_GAP, dtype=object) + token_idx = _restype_to_token_idx(restype, len(token_names)) + return token_names[token_idx] + + +# ------------------------------------------------------------------ +# Atom extraction from OF3 rollout +# ------------------------------------------------------------------ + + +def extract_allatoms( + outputs: dict, feats: dict, cra_name_sfc: list +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract all-atom positions from OF3 rollout in SFC order. + + Returns (positions [N, 3], plddt [N]) with gradients preserved. 
def extract_allatoms(outputs, feats, cra_name_sfc):
    # NOTE(review): the ``def`` line and docstring opening of this function sit
    # above the reviewed region; the signature here is reconstructed from the
    # parameter names used in the body and from the call site in
    # ``position_alignment`` (extract_allatoms(rollout_output, batch, cra_name)).
    # Confirm against the original file before merging.
    """Extract masked atom positions and per-atom pLDDTs from OF3 outputs.

    Parameters
    ----------
    outputs : dict
        OF3 model outputs; must contain ``atom_positions_predicted`` and
        ``plddt_logits``.
    feats : dict
        Input features; must contain ``restype`` and ``atom_mask``.
    cra_name_sfc : list[str]
        Target Chain-Resid-Resname-Atom ordering; the returned tensors are
        reordered to match it exactly.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``[N, 3]`` float32 coordinates and ``[N]`` pLDDT values (0-100),
        both ordered like ``cra_name_sfc``.

    Raises
    ------
    ValueError
        If predicted positions or the atom mask have unexpected shapes.
    AssertionError
        If ``cra_name_sfc`` contains names absent from the extracted atoms.
    """
    token_names = np.asarray(TOKEN_TYPES_WITH_GAP, dtype=object)

    token_idx = _restype_to_token_idx(feats["restype"], len(token_names))
    res_names = token_names[token_idx]
    n_res = int(res_names.shape[0])
    # Single-chain assumption: residues are labelled A-0-, A-1-, ...
    # NOTE(review): multi-chain inputs would need real chain ids here — confirm.
    chain_resid = np.asarray([f"A-{i}-" for i in range(n_res)], dtype=object)

    atom_pos_pred = outputs["atom_positions_predicted"][0].squeeze(0)
    if atom_pos_pred.ndim != 2 or atom_pos_pred.shape[-1] != 3:
        raise ValueError(
            f"Unexpected atom_positions_predicted shape: {tuple(atom_pos_pred.shape)}"
        )
    n_atom = int(atom_pos_pred.shape[0])

    atom_mask = rk_utils.assert_numpy(feats["atom_mask"][0]).reshape(-1)
    if atom_mask.shape[0] != n_atom:
        raise ValueError(f"atom_mask {atom_mask.shape[0]} != n_atom {n_atom}")

    # Expected pLDDT per atom: softmax over bins, dotted with evenly spaced
    # bin centers in [0, 1], scaled to 0-100.
    pl_logits = outputs["plddt_logits"][0].squeeze(0).to(torch.float32)
    probs = torch.softmax(pl_logits, dim=-1)
    bin_centers = torch.linspace(0.0, 1.0, steps=probs.shape[-1], device=probs.device)
    plddt_atom = ((probs * bin_centers).sum(dim=-1) * 100.0).detach().cpu().numpy()

    cra_names: list[str] = []
    atom_positions: list[torch.Tensor] = []
    plddts: list[float] = []

    # Walk the flat atom array in residue / canonical-atom-name order,
    # skipping masked-out atoms (their slots still advance the index).
    global_atom_index = 0
    for i in range(n_res):
        rn = res_names[i]
        resname = rn.decode() if isinstance(rn, (bytes, np.bytes_)) else str(rn)
        atoms = TOKEN_NAME_TO_ATOM_NAMES.get(resname, [])
        for aname in atoms:
            if global_atom_index >= n_atom:
                raise IndexError("global_atom_index ran past atom_positions_predicted")
            if atom_mask[global_atom_index] == 0:
                global_atom_index += 1
                continue

            cra_names.append(f"{chain_resid[i]}{resname}-{aname}")
            atom_positions.append(atom_pos_pred[global_atom_index].to(torch.float32))
            plddts.append(float(plddt_atom[global_atom_index]))
            global_atom_index += 1

    if not atom_positions:
        raise RuntimeError("No atoms extracted (empty atom_positions).")

    positions_atom = torch.stack(atom_positions, dim=0)
    plddt_atom_t = torch.tensor(
        plddts, dtype=torch.float32, device=positions_atom.device
    )

    # Reorder to the caller-supplied CRA ordering; every requested name must exist.
    idx_map = {c: k for k, c in enumerate(cra_names)}
    missing = [c for c in cra_name_sfc if c not in idx_map]
    if missing:
        raise AssertionError(
            f"Topology mismatch; missing {len(missing)} CRA "
            f"(showing 10): {missing[:10]}"
        )

    reorder_index = [idx_map[c] for c in cra_name_sfc]
    return positions_atom[reorder_index], plddt_atom_t[reorder_index]


# ------------------------------------------------------------------
# Alignment
# ------------------------------------------------------------------


def position_alignment(
    rollout_output,
    batch,
    best_pos,
    exclude_res,
    cra_name,
    domain_segs=None,
    reference_bfactor=None,
):
    """Kabsch-align OF3 rollout coordinates onto a reference structure.

    Weights for the alignment are derived from pseudo-B factors (from the
    model's pLDDTs, or from ``reference_bfactor`` when given) via a
    quantile-based soft cutoff.

    Parameters
    ----------
    rollout_output, batch
        OF3 outputs / features forwarded to ``extract_allatoms``.
    best_pos
        Reference ``[N, 3]`` coordinates to align onto.
    exclude_res
        Residue ids excluded from the RMSD diagnostic (see ``cra_exclude_mask``).
    cra_name : list[str]
        Chain-Resid-Resname-Atom ordering shared by model and reference.
    domain_segs
        Optional domain segmentation forwarded to the iterative Kabsch aligner.
    reference_bfactor
        Optional B factors to derive the alignment weights from instead of the
        model's pseudo-Bs; must match their shape.

    Returns
    -------
    tuple
        ``(aligned_xyz, plddts, pseudo_Bs)``.

    Raises
    ------
    ValueError
        If ``reference_bfactor`` does not match the model b-factor shape.
    """
    coords, plddts = extract_allatoms(rollout_output, batch, cra_name)
    pseudo_bs = rk_utils.plddt2pseudoB_pt(plddts)

    if reference_bfactor is None:
        pseudo_b_np = rk_utils.assert_numpy(pseudo_bs)
        cutoff1 = np.quantile(pseudo_b_np, 0.3)
        cutoff2 = cutoff1 * 1.5
        weights = rk_utils.weighting(pseudo_b_np, cutoff1, cutoff2)
    else:
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if reference_bfactor.shape != pseudo_bs.shape:
            raise ValueError("reference_bfactor must match model bfactor shape")
        ref_b_np = rk_utils.assert_numpy(reference_bfactor)
        cutoff1 = np.quantile(ref_b_np, 0.3)
        cutoff2 = cutoff1 * 1.5
        weights = rk_utils.weighting(ref_b_np, cutoff1, cutoff2)

    aligned_xyz = rk_coordinates.iterative_kabsch_alignment(
        coords,
        best_pos,
        cra_name,
        weights=weights,
        exclude_res=exclude_res,
        domain_segs=domain_segs,
    )
    mask = cra_exclude_mask(cra_name, exclude_res=exclude_res)

    # Sanity diagnostic: alignment should not make the weighted RMSD worse.
    rmsd_pre = weighted_rmsd(coords, best_pos, w=weights, mask=mask)
    rmsd_post = weighted_rmsd(aligned_xyz, best_pos, w=weights, mask=mask)

    logger.debug(
        "RMSD pre-align: %.3f post-align: %.3f",
        rmsd_pre,
        rmsd_post,
    )
    if rmsd_post > rmsd_pre + 1e-3:
        logger.warning(
            "RMSD increased after alignment — check ordering / weights / mask."
        )

    return aligned_xyz, plddts, pseudo_bs.detach()


# ------------------------------------------------------------------
# RMSD and masking helpers
# ------------------------------------------------------------------


def _to_numpy(x):
    """Best-effort conversion to ``np.ndarray`` (detaches torch tensors)."""
    if isinstance(x, np.ndarray):
        return x
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def weighted_rmsd(xyz_a, xyz_b, w=None, mask=None):
    """Weighted RMSD between two ``[N, 3]`` coordinate arrays.

    Rows that are non-finite in either input are dropped; ``mask`` (a boolean
    keep-mask of length N) further restricts the comparison.  With ``w`` the
    squared deviations are averaged under the normalised weights.

    Raises
    ------
    ValueError
        If inputs are not matching ``[N, 3]`` arrays, if no points remain
        after masking, or if the retained weights sum to <= 0.
    """
    a = _to_numpy(xyz_a).astype(np.float64, copy=False)
    b = _to_numpy(xyz_b).astype(np.float64, copy=False)
    if not (a.shape == b.shape and a.ndim == 2 and a.shape[1] == 3):
        raise ValueError(f"Expected matching [N, 3] arrays, got {a.shape} vs {b.shape}")

    finite = np.isfinite(a).all(-1) & np.isfinite(b).all(-1)
    if mask is not None:
        finite = mask & finite
    if not finite.any():
        # FIX: previously this fell through to the mean of an empty slice,
        # silently returning nan with a RuntimeWarning.
        raise ValueError("weighted_rmsd: no points remain after masking")

    if w is None:
        diff2 = ((a[finite] - b[finite]) ** 2).sum(-1)
        return float(np.sqrt(diff2.mean()))

    ww = _to_numpy(w).reshape(-1).astype(np.float64, copy=False)
    if ww.shape[0] != a.shape[0]:
        raise ValueError(f"weights length {ww.shape[0]} != n_points {a.shape[0]}")
    ww = ww[finite]
    ww_sum = ww.sum()
    if ww_sum <= 0:
        raise ValueError(f"Non-positive weight sum after masking: {ww_sum}")
    ww = ww / ww_sum
    diff2 = ((a[finite] - b[finite]) ** 2).sum(-1)
    return float(np.sqrt((ww * diff2).sum()))


def cra_exclude_mask(cra_name, exclude_res=None):
    """Build a boolean keep-mask from CRA names and a residue-id exclusion list.

    CRA names look like ``A-12-ALA-CA`` (the resid is the second ``-`` field).
    Returns ``None`` when there is nothing to exclude, which ``weighted_rmsd``
    treats as "keep everything".
    """
    if not exclude_res:
        return None
    ex = {int(r) for r in exclude_res}
    keep = np.ones(len(cra_name), dtype=bool)
    for i, cra in enumerate(cra_name):
        parts = str(cra).split("-")
        if len(parts) >= 2:
            try:
                resid = int(parts[1])
                if resid in ex:
                    keep[i] = False
            except ValueError:
                # Non-numeric resid field: keep the atom rather than guess.
                pass
    return keep


__all__ = [
    "cra_exclude_mask",
    "extract_allatoms",
    "get_res_names",
    "position_alignment",
    "weighted_rmsd",
]
# (In the original patch, the diff header for LossLab/utils/geometry.py followed here.)
# --- LossLab/utils/geometry.py (post-patch) ---


def weighted_kabsch(P, Q, weights=None, *, torch_backend=False):
    """Weighted Kabsch alignment: find R, t minimising sum w_i ||(P-cP)R + cQ - Q||^2.

    Row-vector convention: aligned = (P - cP) @ R + cQ.
    """
    if torch_backend:
        if not isinstance(P, torch.Tensor):
            P = torch.as_tensor(P, dtype=torch.float32)
        if not isinstance(Q, torch.Tensor):
            Q = torch.as_tensor(Q, dtype=torch.float32, device=P.device)

        dev = P.device
        n_pts = P.shape[0]
        wt = (
            torch.ones(n_pts, device=dev, dtype=torch.float32)
            if weights is None
            else torch.as_tensor(weights, device=dev, dtype=torch.float32)
        )
        wn = (wt / wt.sum().clamp_min(1e-8)).view(n_pts, 1)

        cP = (wn * P).sum(dim=0, keepdim=True)
        cQ = (wn * Q).sum(dim=0, keepdim=True)
        X, Y = P - cP, Q - cQ

        # Weighted cross-covariance X^T W Y, SVD'd in float32: H = U S Vh.
        U, _, Vh = torch.linalg.svd(
            ((X * wn).T @ Y).to(torch.float32), full_matrices=False
        )

        # Row-vector Kabsch: R = U D Vh, where D flips the last axis when
        # det(U Vh) < 0 so that R is a proper rotation (no reflection).
        D = torch.eye(3, device=dev, dtype=torch.float32)
        if torch.linalg.det(U @ Vh) < 0:
            D[2, 2] = -1.0
        R = (U @ D @ Vh).to(P.dtype)

        P_aligned = (X @ R) + cQ
        t = cQ.view(3) - (cP.view(3) @ R)
        return R, t, P_aligned

    # ---- numpy path ----
    Pn = _as_numpy(P).astype(np.float64)
    Qn = _as_numpy(Q).astype(np.float64)
    n_pts = Pn.shape[0]
    wt = (
        np.ones((n_pts,), dtype=np.float64)
        if weights is None
        else np.asarray(weights, dtype=np.float64)
    )
    wn = (wt / max(wt.sum(), 1e-8)).reshape(n_pts, 1)

    cP = (wn * Pn).sum(axis=0, keepdims=True)
    cQ = (wn * Qn).sum(axis=0, keepdims=True)
    X, Y = Pn - cP, Qn - cQ

    U, _, Vh = np.linalg.svd((X * wn).T @ Y, full_matrices=False)  # H = U S Vh
    # Row-vector Kabsch: R = U D Vh
    D = np.eye(3, dtype=np.float64)
    if np.linalg.det(U @ Vh) < 0:
        D[2, 2] = -1.0
    R = U @ D @ Vh
    P_aligned = (X @ R) + cQ
    t = cQ.reshape(3) - (cP.reshape(3) @ R)
    return R, t, P_aligned


# --- LossLab/utils/sequence.py: Smith-Waterman common-index discovery ---
# (The original new file starts with `from __future__ import annotations`,
# which is only legal at the top of a module and is therefore omitted here.)

import re

import numpy as np

# 3-letter -> 1-letter codes for the 20 standard amino acids.
_AA3TO1 = {
    "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F",
    "GLY": "G", "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L",
    "MET": "M", "ASN": "N", "PRO": "P", "GLN": "Q", "ARG": "R",
    "SER": "S", "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y",
}


def _get_identical_indices(aligned_a: str, aligned_b: str):
    """Ungapped positions at which the two aligned strings carry the same residue."""
    idx_a: list[int] = []
    idx_b: list[int] = []
    pos_a = pos_b = 0
    for ch_a, ch_b in zip(aligned_a, aligned_b, strict=False):
        if ch_a == "-":
            pos_b += 1
        elif ch_b == "-":
            pos_a += 1
        else:
            if ch_a == ch_b:
                idx_a.append(pos_a)
                idx_b.append(pos_b)
            pos_a += 1
            pos_b += 1
    return np.array(idx_a), np.array(idx_b)


def _get_pattern_index(str_list, pattern: str):
    """Index of the first entry matched by ``re.match(pattern, ...)``, else None."""
    return next(
        (i for i, entry in enumerate(str_list) if re.match(pattern, str(entry))),
        None,
    )


def _sequence_from_cra(cra_name: list[str]) -> str:
    """1-letter sequence from a CRA name list, one letter per residue's CA atom."""
    by_resid: dict[int, str] = {}
    for entry in cra_name:
        fields = str(entry).split("-")
        if len(fields) < 4 or fields[-1] != "CA":
            continue
        try:
            resid = int(fields[1])
        except ValueError:
            continue
        # First CA seen for a resid wins; unknown residue names map to "X".
        if resid not in by_resid:
            by_resid[resid] = _AA3TO1.get(fields[2], "X")
    return "".join(by_resid[r] for r in sorted(by_resid))
def _atom_suffixes_for_selection(selection: str) -> list[str]:
    """Atom-name suffixes for a selection mode; raises on unknown modes.

    ``CA`` -> alpha carbons, ``BB`` -> backbone N/CA/C, ``ALL`` -> every atom
    (the ``.*`` wildcard is only used for validation; see
    ``compute_common_indices``).
    """
    sel = selection.upper()
    if sel == "CA":
        return ["CA"]
    if sel == "BB":
        return ["N", "CA", "C"]
    if sel == "ALL":
        return [".*"]
    raise ValueError(f"selection must be one of: ALL, CA, BB; got {sel}")


def _residue_atoms_by_name(cra_list, resid) -> dict[str, int]:
    """Map atom name -> first index for residue ``resid`` in a CRA name list.

    CRA names look like ``Chain-Resid-Resname-Atom``; the atom name is taken
    as everything after the last ``-`` (atom names are assumed hyphen-free,
    the same assumption the suffix patterns below already make).
    """
    pat = re.compile(rf".*-{resid}-.*-(.+)$")
    out: dict[str, int] = {}
    for i, entry in enumerate(cra_list):
        m = pat.match(str(entry))
        if m:
            out.setdefault(m.group(1), i)
    return out


def compute_common_indices(
    moving_cra: list[str],
    reference_cra: list[str],
    selection: str = "CA",
    moving_sequence: str | None = None,
    reference_sequence: str | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Find matching atom indices using Smith-Waterman alignment.

    Uses ``skbio.alignment.StripedSmithWaterman`` to handle sequences
    with missing residues, truncations, or insertions.

    FIX: for ``selection="ALL"`` the previous implementation used a single
    ``.*`` suffix pattern with a first-match search, so only the FIRST atom
    of each residue was ever paired.  ``ALL`` now pairs every atom name the
    two residues share.

    NOTE(review): the residue ids embedded in the CRA names are matched
    against 0-based sequence positions from the alignment, which assumes
    resids are contiguous and 0-based (as produced by ``extract_allatoms``).
    Confirm before using with externally numbered structures.

    Parameters
    ----------
    moving_cra, reference_cra : list[str]
        CRA name lists (e.g. ``['A-0-ALA-CA', ...]``).
    selection : {'ALL', 'CA', 'BB'}
        Which atom types to match.
    moving_sequence, reference_sequence : str, optional
        1-letter amino acid sequences. If not provided, they are
        inferred from ``cra_name`` entries.

    Returns
    -------
    index_moving, index_reference : np.ndarray
        Matched atom indices into the respective CRA lists.

    Raises
    ------
    ValueError
        On an unknown ``selection`` or when no atoms overlap.
    """
    from skbio.alignment import StripedSmithWaterman

    if moving_sequence is None:
        moving_sequence = _sequence_from_cra(moving_cra)
    if reference_sequence is None:
        reference_sequence = _sequence_from_cra(reference_cra)

    query = StripedSmithWaterman(reference_sequence)
    alignment = query(moving_sequence)

    # Sequence positions covered by the local alignment, then the subset of
    # positions where both sequences carry the identical residue.
    sub_ref = np.arange(alignment.query_begin, alignment.query_end + 1)
    sub_mov = np.arange(
        alignment.target_begin,
        alignment.target_end_optimal + 1,
    )
    subsub_ref, subsub_mov = _get_identical_indices(
        alignment.aligned_query_sequence,
        alignment.aligned_target_sequence,
    )
    common_ref_res = sub_ref[subsub_ref]
    common_mov_res = sub_mov[subsub_mov]

    # Validates `selection` (raises on unknown modes) even though the ALL
    # branch below does not use the wildcard suffix.
    atom_suffixes = _atom_suffixes_for_selection(selection)
    match_all = selection.upper() == "ALL"

    index_moving: list[int] = []
    index_reference: list[int] = []

    for res_ref, res_mov in zip(common_ref_res, common_mov_res, strict=False):
        if match_all:
            ref_atoms = _residue_atoms_by_name(reference_cra, res_ref)
            mov_atoms = _residue_atoms_by_name(moving_cra, res_mov)
            for aname, idx_ref in ref_atoms.items():
                idx_mov = mov_atoms.get(aname)
                if idx_mov is not None:
                    index_reference.append(idx_ref)
                    index_moving.append(idx_mov)
            continue

        for suffix in atom_suffixes:
            pattern_ref = rf".*-{res_ref}-.*-{suffix}$"
            pattern_mov = rf".*-{res_mov}-.*-{suffix}$"
            idx_ref = _get_pattern_index(reference_cra, pattern_ref)
            idx_mov = _get_pattern_index(moving_cra, pattern_mov)
            if idx_ref is not None and idx_mov is not None:
                index_reference.append(idx_ref)
                index_moving.append(idx_mov)

    if not index_moving:
        raise ValueError("No overlapping atoms found between moving and reference")

    return np.array(index_moving), np.array(index_reference)


__all__ = [
    "compute_common_indices",
]