
Commit 0894ccf

drafting a method for any position-specific embedder to accept aligned sequences
1 parent 9829702 commit 0894ccf

4 files changed

Lines changed: 272 additions & 54 deletions
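
In practice, this change lets any embedder that mixes in PositionSpecificMixin accept gapped, aligned input directly: gaps are stripped before the model runs, and the results are remapped back to the requested alignment columns, with gap positions filled by gap_fill_value. A minimal usage sketch, assuming ProteinSequences can be constructed from a list of gapped strings and that the wrapper exposes the usual scikit-learn style fit/transform (neither is shown in this diff):

# Illustrative only; class and parameter names come from the diff below,
# but the ProteinSequences constructor and fit/transform calls are assumptions.
from aide_predict.utils.data_structures import ProteinSequences
from aide_predict.bespoke_models.embedders.aa_properties import AAPropertiesEmbedding

seqs = ProteinSequences(["MK-VL", "M-AVL"])            # aligned sequences with gaps (assumed constructor)
embedder = AAPropertiesEmbedding(positions=[0, 2, 4],  # alignment columns to embed
                                 pool=False, flatten=True,
                                 handle_aligned=True, gap_fill_value=0.0)
emb = embedder.fit(seqs).transform(seqs)
# Alignment columns that are gaps in a given sequence come back as gap_fill_value
# instead of raising, so every row has the same width.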


aide_predict/bespoke_models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .embedders.saprot import SaProtEmbedding
 from .embedders.kmer import KmerEmbedding
 from .embedders.ssemb import SSEmbEmbedding
+from .embedders.aa_properties import AAPropertiesEmbedding


 TOOLS = [
@@ -40,4 +41,5 @@
     SaProtEmbedding,
     KmerEmbedding,
     SSEmbEmbedding,
+    AAPropertiesEmbedding,
 ]

aide_predict/bespoke_models/base.py

Lines changed: 81 additions & 2 deletions
@@ -898,7 +898,8 @@ class PositionSpecificMixin:
 
     This mixin adds functionality for handling position-specific outputs from protein models.
     It allows selecting specific positions to analyze, pooling across positions, and
-    flattening multi-dimensional outputs.
+    flattening multi-dimensional outputs. It can also automatically handle aligned sequences
+    with gaps by stripping gaps before processing and remapping embeddings back to aligned positions.
 
     Attributes:
         positions (Optional[List[int]]): The positions to output scores for. If None, all positions are used.
@@ -908,14 +909,21 @@ class PositionSpecificMixin:
             - If callable: Uses the provided function for pooling
             - If False: No pooling is performed
         flatten (bool): Whether to flatten dimensions beyond the second dimension.
+        handle_aligned (bool): If True, automatically strip gaps before processing and remap to aligned positions.
+        gap_fill_value (float): Value to use for gap positions in aligned sequences (default 0.0).
     """
     _per_position_capable: bool = True
 
-    def _init_handler(self, positions=None, pool=True, flatten=True, **kwargs):
+    def _init_handler(self, positions=None, pool=True, flatten=True,
+                      handle_aligned=True, gap_fill_value=0.0, **kwargs):
         """Initialize position-specific attributes from kwargs."""
         self.positions = positions
         self.pool = pool
         self.flatten = flatten
+        self.handle_aligned = handle_aligned
+        self.gap_fill_value = gap_fill_value
+        # Temporary storage for alignment mapping during transform
+        self._alignment_mapping = None
 
     def _is_ragged_array(self, arr):
         """Check if the input is a ragged array (list of arrays with different shapes)."""
@@ -934,9 +942,71 @@ def _pre_transform_hook(self, X):
         if self.positions is not None:
             if not (X.aligned or len(X) == 1):
                 raise ValueError("Input sequences must be same length / aligned for position-specific output.")
+
+        # Handle aligned sequences if enabled
+        if self.handle_aligned and X.has_gaps:
+            # Store mapping in instance for post-transform hook
+            self._alignment_mapping = X.get_alignment_mapping()
+            # Convert mapping to have integer keys ascending from 0
+            self._alignment_mapping = {i: m for i, m in enumerate(self._alignment_mapping.values())}
+            X = X.with_no_gaps()
+
+            # Validate behavior is well-defined
+            if self.positions is None and not self.pool:
+                raise ValueError(
+                    "Cannot return position-specific embeddings for sequences with gaps "
+                    "unless positions are specified or pooling is enabled."
+                )
+        else:
+            self._alignment_mapping = None
 
         return X
 
+    def _remap_to_aligned_positions(self, result, mapping, positions, fill_value):
+        """
+        Remap embeddings from ungapped sequences back to aligned positions.
+
+        Args:
+            result: List of embeddings for ungapped sequences
+            mapping: Dict mapping sequence index to list of aligned positions
+            positions: List of aligned positions to extract
+            fill_value: Value to use for gap positions
+
+        Returns:
+            List of remapped embeddings with gaps represented by fill_value
+        """
+        aligned_embeddings = []
+        for i, emb in enumerate(result):
+            seq_mapping = mapping[i]
+            # emb shape: (1, seq_len, embedding_dim), (seq_len, embedding_dim), or (seq_len,) for 1D
+            # Remove batch dimension if present
+            if emb.ndim == 3 and emb.shape[0] == 1:
+                emb = emb[0]  # Now (seq_len, embedding_dim)
+
+            if emb.ndim == 1:
+                # Handle 1D case (e.g., single position or pooled)
+                emb = np.expand_dims(emb, 0)
+                squeeze_after = True
+            else:
+                squeeze_after = False
+
+            # Now emb is (seq_len, embedding_dim)
+            seq_len = emb.shape[0]
+            embedding_dim = emb.shape[-1] if emb.ndim > 1 else 1
+            aligned_emb = np.full((len(positions), embedding_dim), fill_value, dtype=emb.dtype)
+
+            for j, pos in enumerate(positions):
+                if pos in seq_mapping:
+                    aligned_pos = seq_mapping.index(pos)
+                    if aligned_pos < seq_len:
+                        aligned_emb[j] = emb[aligned_pos]
+
+            if squeeze_after:
+                aligned_emb = np.squeeze(aligned_emb, axis=-1)
+
+            aligned_embeddings.append(aligned_emb)
+        return aligned_embeddings
+
     def _post_transform_hook(self, result, X):
         """
         Process the model output to handle position selection, pooling, and flattening.
@@ -950,6 +1020,15 @@ def _post_transform_hook(self, result, X):
         """
         if result is None or len(result) == 0:
             return result
+
+        # Remap to aligned positions if we have a mapping and positions were specified
+        if self._alignment_mapping is not None and self.positions is not None:
+            result = self._remap_to_aligned_positions(
+                result, self._alignment_mapping, self.positions, self.gap_fill_value
+            )
+            # Clean up temporary storage
+            self._alignment_mapping = None
+
         if self.pool:
             # get the pool function
             if isinstance(self.pool, str):
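
To make the remapping step above concrete, here is a minimal, self-contained numpy sketch of what _remap_to_aligned_positions does for a single sequence: the mapping lists the alignment column of each ungapped residue, and requested columns missing from the sequence are filled with gap_fill_value. The values below are illustrative, not taken from the library.

import numpy as np

seq_mapping = [0, 2, 3]                 # alignment column of each ungapped residue (e.g. "M-KV")
emb = np.arange(9.0).reshape(3, 3)      # (ungapped_len, embedding_dim) embedding for that sequence
positions = [0, 1, 3]                   # alignment columns requested by the user
gap_fill_value = 0.0

aligned_emb = np.full((len(positions), emb.shape[1]), gap_fill_value, dtype=emb.dtype)
for j, pos in enumerate(positions):
    if pos in seq_mapping:              # this column exists in the ungapped sequence
        aligned_emb[j] = emb[seq_mapping.index(pos)]

print(aligned_emb)
# [[0. 1. 2.]   <- column 0
#  [0. 0. 0.]   <- column 1 is a gap, filled with gap_fill_value
#  [6. 7. 8.]]  <- column 3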
aide_predict/bespoke_models/embedders/aa_properties.py

Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
+# aide_predict/bespoke_models/embedders/aa_properties.py
+'''
+* Author: Evan Komp
+* Created: 11/24/2024
+* Company: National Renewable Energy Lab, Bioenergy Science and Technology
+* License: MIT
+
+Simple amino acid property embedder for testing position-specific functionality.
+'''
+import numpy as np
+from typing import List, Union, Optional
+
+from aide_predict.bespoke_models.base import (
+    ProteinModelWrapper,
+    PositionSpecificMixin,
+    CanHandleAlignedSequencesMixin,
+    ExpectsNoFitMixin
+)
+from aide_predict.utils.data_structures import ProteinSequences, ProteinSequence
+from aide_predict.utils.common import MessageBool
+
+import logging
+logger = logging.getLogger(__name__)
+
+AVAILABLE = MessageBool(True, "AAPropertiesEmbedding is always available")
+
+
+# Simple physicochemical properties for the 20 standard amino acids
+AA_PROPERTIES = {
+    'A': [1.8, 0.0, 0.0],    # Alanine: hydrophobicity, charge, size
+    'C': [2.5, 0.0, 0.0],    # Cysteine
+    'D': [-3.5, -1.0, 0.0],  # Aspartic acid
+    'E': [-3.5, -1.0, 0.5],  # Glutamic acid
+    'F': [2.8, 0.0, 1.0],    # Phenylalanine
+    'G': [-0.4, 0.0, -1.0],  # Glycine
+    'H': [-3.2, 0.5, 0.5],   # Histidine
+    'I': [4.5, 0.0, 0.5],    # Isoleucine
+    'K': [-3.9, 1.0, 0.5],   # Lysine
+    'L': [3.8, 0.0, 0.5],    # Leucine
+    'M': [1.9, 0.0, 0.5],    # Methionine
+    'N': [-3.5, 0.0, 0.0],   # Asparagine
+    'P': [-1.6, 0.0, 0.0],   # Proline
+    'Q': [-3.5, 0.0, 0.5],   # Glutamine
+    'R': [-4.5, 1.0, 1.0],   # Arginine
+    'S': [-0.8, 0.0, -0.5],  # Serine
+    'T': [-0.7, 0.0, 0.0],   # Threonine
+    'V': [4.2, 0.0, 0.0],    # Valine
+    'W': [-0.9, 0.0, 1.5],   # Tryptophan
+    'Y': [-1.3, 0.0, 1.0],   # Tyrosine
+}
+
+
+class AAPropertiesEmbedding(
+    ExpectsNoFitMixin,
+    PositionSpecificMixin,
+    CanHandleAlignedSequencesMixin,
+    ProteinModelWrapper
+):
+    """
+    A simple amino acid property embedder for testing position-specific functionality.
+
+    This embedder converts each amino acid to a 3-dimensional vector based on:
+    - Hydrophobicity (Kyte-Doolittle scale approximation)
+    - Charge (at physiological pH)
+    - Size (relative volume)
+
+    This is a simple, fast embedder that can handle aligned sequences with gaps
+    and is useful for testing the PositionSpecificMixin functionality.
+
+    Attributes:
+        positions (Optional[List[int]]): Specific positions to encode. If None, all positions are encoded.
+        pool (bool): Whether to pool the encoded vectors across positions.
+        flatten (bool): Whether to flatten the output array.
+        handle_aligned (bool): Whether to handle aligned sequences with gaps.
+        gap_fill_value (float): Value to use for gap positions.
+    """
+
+    _available = AVAILABLE
+
+    def __init__(
+        self,
+        metadata_folder: str = None,
+        positions: Optional[List[int]] = None,
+        flatten: bool = False,
+        pool: bool = False,
+        handle_aligned: bool = True,
+        gap_fill_value: float = 0.0,
+        wt: Optional[Union[str, ProteinSequence]] = None,
+        **kwargs
+    ):
+        """
+        Initialize the AAPropertiesEmbedding.
+
+        Args:
+            metadata_folder (str): The folder where metadata is stored.
+            positions (Optional[List[int]]): Specific positions to encode. If None, all positions are encoded.
+            flatten (bool): Whether to flatten the output array.
+            pool (bool): Whether to pool the encoded vectors across positions.
+            handle_aligned (bool): Whether to handle aligned sequences with gaps.
+            gap_fill_value (float): Value to use for gap positions.
+            wt (Optional[Union[str, ProteinSequence]]): The wild type sequence, if any.
+        """
+        super().__init__(
+            metadata_folder=metadata_folder,
+            wt=wt,
+            positions=positions,
+            pool=pool,
+            flatten=flatten,
+            handle_aligned=handle_aligned,
+            gap_fill_value=gap_fill_value,
+            **kwargs
+        )
+        self.embedding_dim_ = 3  # 3 properties per amino acid
+
+    def _fit(self, X: ProteinSequences, y: Optional[np.ndarray] = None) -> 'AAPropertiesEmbedding':
+        """
+        Fit the embedder (no actual fitting needed as properties are predefined).
+
+        Args:
+            X (ProteinSequences): The input protein sequences.
+            y (Optional[np.ndarray]): Ignored. Present for API consistency.
+
+        Returns:
+            AAPropertiesEmbedding: The fitted embedder.
+        """
+        self.fitted_ = True
+        return self
+
+    def _transform(self, X: ProteinSequences) -> List[np.ndarray]:
+        """
+        Transform the protein sequences into amino acid property embeddings.
+
+        Args:
+            X (ProteinSequences): The input protein sequences.
+
+        Returns:
+            List[np.ndarray]: The amino acid property embeddings for the sequences.
+        """
+        all_embeddings = []
+
+        for seq in X:
+            seq_str = str(seq).upper()
+            seq_len = len(seq_str)
+
+            # Create embedding matrix: (1, seq_len, 3)
+            embedding = np.zeros((1, seq_len, 3), dtype=np.float32)
+
+            for i, aa in enumerate(seq_str):
+                if aa in AA_PROPERTIES:
+                    embedding[0, i, :] = AA_PROPERTIES[aa]
+                else:
+                    # Unknown amino acid - use zeros
+                    logger.warning(f"Unknown amino acid '{aa}' in sequence {seq.id}, using zeros")
+                    embedding[0, i, :] = [0.0, 0.0, 0.0]
+
+            all_embeddings.append(embedding)
+
+        # Return as list - PositionSpecificMixin will handle position selection, pooling, and alignment remapping
+        return all_embeddings
+
+    def get_feature_names_out(self, input_features: Optional[List[str]] = None) -> List[str]:
+        """
+        Get output feature names for transformation.
+
+        Args:
+            input_features (Optional[List[str]]): Ignored. Present for API consistency.
+
+        Returns:
+            List[str]: Output feature names.
+        """
+        if not hasattr(self, 'fitted_'):
+            raise ValueError("Model has not been fitted yet. Call fit() before using this method.")
+
+        positions = self.positions
+        property_names = ['hydrophobicity', 'charge', 'size']
+
+        if self.pool:
+            return [f"AAProps_{prop}" for prop in property_names]
+        elif self.flatten:
+            if positions is None:
+                raise ValueError("Cannot return feature names for flattened embeddings without specifying positions")
+            return [f"pos{p}_{prop}" for p in positions for prop in property_names]
+        else:
+            raise ValueError("Cannot return feature names for non-flattened non-pooled embeddings.")

aide_predict/bespoke_models/embedders/esm2.py

Lines changed: 5 additions & 52 deletions
@@ -149,20 +149,9 @@ def _transform(self, X: ProteinSequences) -> np.ndarray:
             raise ValueError("Cannot flatten variable length sequences without positions or pooling.")
         warnings.warn("Variable length sequences are being processed without positions or pooling, raw shapes will be output.")
 
-        mapping = None
-        if X.has_gaps:
-            # here we need to store a mapping such that if positions were specified we can map back to
-            # the aligned positions
-            mapping = X.get_alignment_mapping()
-            # convert mapping to have integer keys ascending from 0
-            mapping = {i: m for i, m in enumerate(mapping.values())}
-
-            X = X.with_no_gaps()
-            # raise if positions were not passed - here behavior is uncertain
-            if self.positions is None and not self.pool:
-                raise ValueError("Cannot return position-specific embeddings for sequences with gaps unless positions are specified or pooling is on.")
-
-        base_index = 0
+        # Note: gap handling is now managed by PositionSpecificMixin hooks
+        # X will arrive here without gaps if handle_aligned=True
+
         bar = tqdm.tqdm(total=len(X), desc="Computing ESM2 embeddings")
         for batch in X.iter_batches(self.batch_size):
             batch_sequences = self._prepare_sequences(batch)
@@ -181,49 +170,13 @@ def _transform(self, X: ProteinSequences) -> np.ndarray:
             # Remove special tokens (assuming first and last tokens are special)
             embeddings = [emb[1:-1] for emb in embeddings]
 
-            if self.positions is not None and mapping is None:
-                # here we have fixed length so we can just use positions
-                embeddings = [emb[self.positions] for emb in embeddings]
-            elif self.positions is not None and mapping is not None:
-                # here we have variable length sequences that were input as an aligned set,
-                # the user asked for positions in the alignment
-                aligned_embeddings = []
-                for i, emb in enumerate(embeddings):
-                    seq_mapping = mapping[base_index + i]
-                    aligned_emb = np.zeros((len(self.positions), emb.shape[1]))
-                    for j, pos in enumerate(self.positions):
-                        if pos in seq_mapping:
-                            aligned_pos = seq_mapping.index(pos)
-                            aligned_emb[j] = emb[aligned_pos]
-                        # If pos is not in seq_mapping, it remains a zero vector
-                    aligned_embeddings.append(aligned_emb)
-                embeddings = aligned_embeddings
-            else:
-                # Here positions were not specified and either have fixed length or pooling
-                # is on
-                pass
-
-            if self.pool:
-                if self.pool == 'mean' or self.pool is True:
-                    embeddings = [emb.mean(axis=0) for emb in embeddings]
-                elif self.pool == 'max':
-                    embeddings = [emb.max(axis=0) for emb in embeddings]
-                elif hasattr(np, self.pool):
-                    # check if the pool is a numpy function
-                    pool_func = getattr(np, self.pool)
-                    embeddings = [pool_func(emb, axis=0) for emb in embeddings]
-                else:
-                    raise ValueError(f"Invalid pooling method: {self.pool}")
-
-            # add 0th dimension
+            # Add 0th dimension for stacking
             embeddings = [np.expand_dims(emb, 0) for emb in embeddings]
             all_embeddings.extend(embeddings)
 
-            base_index += len(batch)
-
             bar.update(len(batch))
 
-        # stack along 0 dimension
+        # Return as list - PositionSpecificMixin will handle position selection, pooling, and alignment remapping
         return all_embeddings
 
     def get_feature_names_out(self, input_features: Optional[List[str]] = None) -> List[str]:
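
The deletions above only work if the wrapper invokes the mixin hooks around _transform. That call order is not part of this diff, so the following is just a sketch of the assumed flow, not ProteinModelWrapper's actual implementation:

# Assumed orchestration (hook names are from base.py; the wrapper method itself is not shown here).
def transform_sketch(model, X):
    X = model._pre_transform_hook(X)               # strips gaps, records model._alignment_mapping
    result = model._transform(X)                   # model-specific embedding of ungapped sequences
    return model._post_transform_hook(result, X)   # remaps to aligned positions, pools, flattens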
