Skip to content

Commit eae5812

Browse files
committed
updated test for ssemb uses MSA from original papers and matches scores
1 parent cc65a26 commit eae5812

5 files changed

Lines changed: 54145 additions & 3 deletions

File tree

aide_predict/utils/msa.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,132 @@ def compute_conservation(self, msa, normalize=True, gap_treatment='exclude', gap
478478

479479
logger.info(f"Conservation calculation complete: min={conservation.min():.4f}, max={conservation.max():.4f}, mean={conservation.mean():.4f}")
480480
return conservation
481+
482+
483+
def remove_gappy_columns(self,
                         msa: ProteinSequences,
                         gap_threshold: Optional[float] = None,
                         focus_seq_id: Optional[str] = None) -> ProteinSequences:
    """
    Remove columns from the MSA that exceed the gap threshold.

    Columns whose gap fraction is above the threshold are removed from the
    alignment. Sequences are retained unless they consist solely of gap
    characters once the columns have been removed; such sequences are dropped.

    Args:
        msa (ProteinSequences): The input multiple sequence alignment.
        gap_threshold (Optional[float]): Maximum allowed fraction of gaps per column.
            If None, uses self.threshold_focus_cols_frac_gaps.
        focus_seq_id (Optional[str]): If provided, only columns where the focus
            sequence is not a gap ('-') are candidates; every column where the
            focus sequence is a gap is removed. If None, all columns are candidates.

    Returns:
        ProteinSequences: New MSA with high-gap columns (and any sequences left
            entirely as gaps) removed. Sequence ids, structures, per-sequence MSA
            references and, when present, weights are carried over.

    Raises:
        ValueError: If the input MSA is not aligned.
        ValueError: If gap_threshold is not between 0 and 1.
        ValueError: If focus_seq_id is provided but not found in MSA.
        ValueError: If there are no candidate columns, or none pass the threshold.
        ValueError: If every sequence is all gaps after column removal.
    """
    if not msa.aligned:
        raise ValueError("Input MSA must be aligned")

    if gap_threshold is None:
        gap_threshold = self.threshold_focus_cols_frac_gaps

    if not 0 <= gap_threshold <= 1:
        raise ValueError("gap_threshold must be between 0 and 1")

    logger.info(f"Removing columns with gap fraction > {gap_threshold}")
    logger.debug(f"Input MSA: {len(msa)} sequences, {msa.width} columns")

    # Work on the array form of the alignment (assumed sequences x columns of
    # single characters -- TODO confirm against ProteinSequences.as_array).
    msa_array = msa.as_array()

    # With a focus sequence, only its non-gap positions are candidate columns.
    if focus_seq_id is not None:
        if focus_seq_id not in msa.id_mapping:
            raise ValueError(f"Focus sequence ID '{focus_seq_id}' not found in MSA")

        focus_seq_array = np.array(list(str(msa[focus_seq_id])))
        focus_positions = focus_seq_array != '-'
        candidate_array = msa_array[:, focus_positions]

        logger.debug(f"Focus sequence '{focus_seq_id}' has {np.sum(focus_positions)} non-gap positions")
    else:
        focus_positions = None
        candidate_array = msa_array

    # NOTE(review): columns count only '-' as a gap, whereas the all-gap sequence
    # check below uses GAP_CHARACTERS. Left as-is to keep results unchanged;
    # confirm whether the two definitions should be unified.
    gap_fractions = np.mean(candidate_array == '-', axis=0)

    # Without this guard the min()/max() calls in the debug log below fail on an
    # empty array with an opaque NumPy error (e.g. an all-gap focus sequence).
    if gap_fractions.size == 0:
        raise ValueError("No columns available for gap filtering: the MSA has no columns "
                         "or the focus sequence has no non-gap positions.")

    # Columns at or below the threshold are kept.
    columns_to_keep_filtered = gap_fractions <= gap_threshold

    logger.debug(f"Gap fractions: min={gap_fractions.min():.3f}, "
                 f"max={gap_fractions.max():.3f}, mean={gap_fractions.mean():.3f}")
    logger.debug(f"Columns passing threshold: {np.sum(columns_to_keep_filtered)}/{len(columns_to_keep_filtered)}")

    # Map the decision back onto the full set of columns; columns outside the
    # focus positions stay False and are therefore dropped.
    if focus_positions is not None:
        columns_to_keep = np.zeros(msa.width, dtype=bool)
        columns_to_keep[focus_positions] = columns_to_keep_filtered
    else:
        columns_to_keep = columns_to_keep_filtered

    if not columns_to_keep.any():
        raise ValueError(f"No columns pass the gap threshold of {gap_threshold}. "
                         f"Consider increasing the threshold.")

    filtered_msa_array = msa_array[:, columns_to_keep]

    # Rebuild each sequence from the surviving columns, remembering the row index
    # of every sequence that is kept so weights can be subset consistently.
    filtered_sequences = []
    valid_sequence_indices = []

    for i, original_seq in enumerate(msa):
        filtered_seq_str = ''.join(filtered_msa_array[i])

        # A sequence left with nothing but gap characters carries no information.
        if all(char in GAP_CHARACTERS for char in filtered_seq_str):
            logger.debug(f"Removing sequence '{original_seq.id}' - all gaps after column filtering")
            continue

        # Create new ProteinSequence preserving metadata
        filtered_seq = ProteinSequence(
            filtered_seq_str,
            id=original_seq.id,
            structure=original_seq.structure
        )

        # Preserve MSA reference if it exists
        if original_seq.has_msa:
            filtered_seq.msa = original_seq.msa

        filtered_sequences.append(filtered_seq)
        valid_sequence_indices.append(i)

    if not filtered_sequences:
        raise ValueError("No sequences remain after removing all-gap sequences. "
                         f"Consider relaxing the gap threshold (current: {gap_threshold})")

    filtered_msa = ProteinSequences(filtered_sequences)

    # Carry over weights for the retained sequences only. np.asarray makes the
    # list-based index work even if weights are stored as a plain sequence.
    weights = getattr(msa, 'weights', None)
    if weights is not None:
        filtered_msa.weights = np.asarray(weights)[valid_sequence_indices]

    removed_sequences = len(msa) - len(filtered_msa)
    logger.info(f"Filtered MSA: {len(filtered_msa)} sequences, {filtered_msa.width} columns "
                f"(removed {msa.width - filtered_msa.width} columns, {removed_sequences} all-gap sequences)")

    return filtered_msa

0 commit comments

Comments
 (0)