5 changes: 4 additions & 1 deletion ci/tests/test_c2s/test_c2s.py
@@ -186,8 +186,11 @@ def test_empty_dataset(self, cell2sen_model, sample_anndata):
cell2sen_model.get_embeddings(empty)

def test_attention_shapes(self, cell2sen_model, processed_dataset_basic):
emb, attn = cell2sen_model.get_embeddings(processed_dataset_basic, output_attentions=True)
emb, attn, gene_order = cell2sen_model.get_embeddings(processed_dataset_basic, output_attentions=True)
assert isinstance(attn, list)
assert isinstance(gene_order, list)
assert len(gene_order) == len(processed_dataset_basic)
assert all(isinstance(gene_list, list) for gene_list in gene_order)
assert len(attn) == len(processed_dataset_basic)
for sample_attn in attn:
# (num_heads, num_genes, num_genes)
4 changes: 2 additions & 2 deletions examples/notebooks/Cell2Sen-Tutorial.ipynb
@@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "2d5ac9f3",
"metadata": {},
"outputs": [
@@ -219,7 +219,7 @@
"\n",
"embeddings = cell2sen_model.get_embeddings(processes_dataset)\n",
"\n",
"# embeddings, attentions = cell2sen_model.get_embeddings(processes_dataset, output_attentions=True)\n",
"# embeddings, attentions, gene_order = cell2sen_model.get_embeddings(processes_dataset, output_attentions=True)\n",
"\n",
"print(embeddings.shape)\n"
]
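A minimal usage sketch for the updated call above, assuming cell2sen_model and processes_dataset are set up as in the tutorial and that the first cell's sentence contains at least two genes; the gene names used are simply whatever gene_order reports for that cell:

embeddings, attentions, gene_order = cell2sen_model.get_embeddings(
    processes_dataset, output_attentions=True
)

cell_idx = 0
attn = attentions[cell_idx]                      # (num_heads, num_genes, num_genes) numpy array
genes = gene_order[cell_idx][: attn.shape[1]]    # gene names in sentence order, truncated to the attention map

# Attention from one gene to another, averaged over heads (genes picked here as placeholders).
gene_a, gene_b = genes[0], genes[1]
i, j = genes.index(gene_a), genes.index(gene_b)
print(f"attention {gene_a} -> {gene_b}: {attn.mean(axis=0)[i, j]:.4f}")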
93 changes: 48 additions & 45 deletions helical/models/c2s/model.py
@@ -162,42 +162,44 @@ def _gene_ids_from_offsets(prompt, cell_sentence, offsets):
break
return gene_ids


@staticmethod
def _aggregate_token_to_word_attention(attn, word_ids):
"""
Aggregate a token-level attention matrix to word-level.

Parameters
----------
attn : np.ndarray
Token-level attention of shape (num_heads, seq_len, seq_len).
word_ids : list[int | None]
Word/gene id for each token position (None for non-gene tokens).

Returns
-------
word_attn : np.ndarray
Word-level attention of shape (num_heads, num_words, num_words).
Aggregate token-level attention to word-level.
Works with both numpy arrays (CPU) and torch tensors (GPU).
"""
use_torch = isinstance(attn, torch.Tensor)

# Build word_to_tokens mapping
word_to_tokens = {}
for tok_idx, wid in enumerate(word_ids):
if wid is not None:
word_to_tokens.setdefault(wid, []).append(tok_idx)

num_words = len(word_to_tokens)
if num_words == 0:
LOGGER.warning("No words found in attention map. Returning empty array.")
return torch.zeros((attn.shape[0], 0, 0), dtype=attn.dtype, device=attn.device)

sorted_word_ids = sorted(word_to_tokens.keys())
num_heads = attn.shape[0]
num_heads, seq_len, _ = attn.shape


word_attn = np.zeros((num_heads, num_words, num_words), dtype=attn.dtype)
# GPU path - much faster
W = torch.zeros((num_words, seq_len), dtype=attn.dtype, device=attn.device)
V = torch.zeros((num_words, seq_len), dtype=attn.dtype, device=attn.device)

for wi, wid in enumerate(sorted_word_ids):
token_indices = word_to_tokens[wid]
W[wi, token_indices] = 1.0 / len(token_indices)
V[wi, token_indices] = 1.0

temp = torch.einsum('wt,htk->hwk', W, attn)
word_attn = torch.einsum('hwk,vk->hwv', temp, V)
return word_attn.float().cpu().numpy() # .float() converts bfloat16->float32

for wi, src_wid in enumerate(sorted_word_ids):
src_tokens = word_to_tokens[src_wid]
for wj, tgt_wid in enumerate(sorted_word_ids):
tgt_tokens = word_to_tokens[tgt_wid]
block = attn[:, src_tokens, :][:, :, tgt_tokens] # (H, |src|, |tgt|)
word_attn[:, wi, wj] = block.sum(axis=2).mean(axis=1)

return word_attn

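The vectorised path above collapses the removed per-pair loop into two matrix products: W averages each source word's token rows and V sums each target word's token columns, which reproduces the old block.sum(axis=2).mean(axis=1) reduction. A small self-contained sketch of that equivalence, using toy shapes and made-up word_ids:

import torch

# Toy setup: 2 heads, 5 tokens, 3 words; word_ids assigns tokens to words (None = non-gene token).
torch.manual_seed(0)
attn = torch.rand(2, 5, 5)
word_ids = [0, 0, 1, None, 2]

word_to_tokens = {}
for tok_idx, wid in enumerate(word_ids):
    if wid is not None:
        word_to_tokens.setdefault(wid, []).append(tok_idx)
sorted_word_ids = sorted(word_to_tokens)
num_words = len(sorted_word_ids)
num_heads, seq_len, _ = attn.shape

# Vectorised aggregation: W averages over source tokens, V sums over target tokens.
W = torch.zeros(num_words, seq_len)
V = torch.zeros(num_words, seq_len)
for wi, wid in enumerate(sorted_word_ids):
    toks = word_to_tokens[wid]
    W[wi, toks] = 1.0 / len(toks)
    V[wi, toks] = 1.0
fast = torch.einsum('hwk,vk->hwv', torch.einsum('wt,htk->hwk', W, attn), V)

# Reference: the per-pair loop this replaces.
slow = torch.zeros(num_heads, num_words, num_words)
for wi, src in enumerate(sorted_word_ids):
    for wj, tgt in enumerate(sorted_word_ids):
        block = attn[:, word_to_tokens[src], :][:, :, word_to_tokens[tgt]]
        slow[:, wi, wj] = block.sum(dim=2).mean(dim=1)

assert torch.allclose(fast, slow)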
def process_data(
self,
@@ -359,6 +361,7 @@ def get_embeddings(
emb_layer : int, optional
Which layer to extract attention from (default: -1, i.e. last layer).
Only used when output_attentions=True.
Only one layer of attention can be returned at a time.

Returns:
--------
@@ -367,6 +370,9 @@
attn_list : list, optional
If output_attentions=True, a list of gene-level attention arrays,
one per sample, each of shape (num_heads, num_genes, num_genes).
gene_names_list : list, optional
If output_attentions=True, a list of gene name lists, one list per sample,
e.g. [['geneA', 'geneB', ...], ['geneX', 'geneY', ...], ...]. This is used to map attention values to specific genes.
"""

LOGGER.info("Extracting embeddings from dataset")
@@ -380,7 +386,7 @@
organisms_list = dataset['organism']

all_embeddings = []
all_attentions = []
all_attentions = [[]] # Single list for the one layer we process

progress_bar = tqdm(total=len(sentences_list), desc="Processing embeddings")
for i in range(0, len(sentences_list), self.batch_size):
@@ -444,25 +450,21 @@
if output_attentions:
# outputs.attentions is a tuple of tensors, one per layer
# Each tensor has shape (batch_size, num_heads, seq_length, seq_length)
# Initialize all_attentions_per_layer on first batch
if len(all_attentions) == 0:
num_layers = len(outputs.attentions)
all_attentions = [[] for _ in range(num_layers)]
# Only process the selected layer (emb_layer)


# Aggregate token-level attention to gene-level per sample
batch_size_actual = inputs['input_ids'].shape[0]
for layer_idx, attn in enumerate(outputs.attentions):
attn_np = attn.float().cpu().numpy() # (B, H, L, L)
word_attns = []
for b in range(batch_size_actual):
offsets_b = batch_offsets[b].tolist()
gene_ids = self._gene_ids_from_offsets(
prompts[b], batch_sentences[b], offsets_b
)
word_attns.append(
self._aggregate_token_to_word_attention(attn_np[b], gene_ids)
)
all_attentions[layer_idx].append(word_attns)
attn = outputs.attentions[emb_layer]  # allow returning only one layer of attention
word_attns = []
for b in range(batch_size_actual):
offsets_b = batch_offsets[b].tolist()
gene_ids = self._gene_ids_from_offsets(
prompts[b], batch_sentences[b], offsets_b
)
word_attns.append(
self._aggregate_token_to_word_attention(attn[b], gene_ids)
)
all_attentions[0].append(word_attns)
del outputs

all_embeddings.append(batch_embeddings.float().cpu().numpy())
@@ -474,14 +476,15 @@
# Restore the original attention implementation
self.model.config._attn_implementation = self.attn_implementation

# Flatten per-batch lists into a single list per layer
# Flatten per-batch lists for the selected layer
stacked_attentions = [
[arr for batch_list in all_attentions[layer_idx] for arr in batch_list]
for layer_idx in range(len(all_attentions))
[arr for batch_list in all_attentions[0] for arr in batch_list]
]
# Return only the selected layer as a flat list (like Geneformer)
attn_list = stacked_attentions[emb_layer]
return np.concatenate(all_embeddings, axis=0), attn_list
attn_list = stacked_attentions[0]
# Gene names per sample from cell sentences
gene_names_list = [sentence.split() for sentence in sentences_list]
return np.concatenate(all_embeddings, axis=0), attn_list, gene_names_list
else:
return np.concatenate(all_embeddings, axis=0)

4 changes: 2 additions & 2 deletions helical/models/transcriptformer/data/dataloader.py
@@ -27,15 +27,15 @@ def load_data(file_path):
return None, False


def load_gene_features(adata, gene_col_name):
def load_gene_features(adata, gene_col_name, species: str = "hsapiens"):
"""Load ensembl ids from adata object."""
if gene_col_name == "index":
adata.var["index"] = adata.var_names
elif gene_col_name not in adata.var.columns:
message = f"Gene column '{gene_col_name}' not found in adata.var.columns. Available columns: {adata.var.columns}. Modify config accordingly."
logging.error(message)
raise ValueError(message)
adata = map_gene_symbols_to_ensembl_ids(adata, gene_names=gene_col_name)
adata = map_gene_symbols_to_ensembl_ids(adata, gene_names=gene_col_name, species=species)
gene_names = np.array(list(adata.var["ensembl_id"].values))
return gene_names, True

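With the added species argument, callers can forward a non-human Ensembl species to the symbol-to-ID mapping instead of the hsapiens default. A brief sketch, assuming map_gene_symbols_to_ensembl_ids accepts standard Ensembl species names such as "mmusculus" and using a placeholder file and gene-symbol column:

from helical.models.transcriptformer.data.dataloader import load_gene_features
import anndata as ad

adata = ad.read_h5ad("mouse_cells.h5ad")   # placeholder file with symbols in adata.var["gene_symbols"]

# Default behaviour is unchanged (hsapiens):
gene_names, ok = load_gene_features(adata, gene_col_name="gene_symbols")

# New: forward the species so symbols are mapped against the right Ensembl annotation.
gene_names, ok = load_gene_features(adata, gene_col_name="gene_symbols", species="mmusculus")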
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.6.0"
version = "1.7.0"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]