11 changes: 6 additions & 5 deletions ci/tests/test_c2s/test_c2s.py
@@ -187,11 +187,12 @@ def test_empty_dataset(self, cell2sen_model, sample_anndata):

def test_attention_shapes(self, cell2sen_model, processed_dataset_basic):
emb, attn = cell2sen_model.get_embeddings(processed_dataset_basic, output_attentions=True)
assert isinstance(attn, tuple)
for layer in attn:
assert layer.ndim == 4
assert layer.shape[0] == len(processed_dataset_basic)
assert layer.shape[2] == layer.shape[3]
assert isinstance(attn, list)
assert len(attn) == len(processed_dataset_basic)
for sample_attn in attn:
# (num_heads, num_genes, num_genes)
assert sample_attn.ndim == 3
assert sample_attn.shape[1] == sample_attn.shape[2]


class TestGetPerturbations:
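For reference, a minimal sketch of how a downstream caller (or a follow-up test) might consume the new per-sample, gene-level attention format. It assumes the same `cell2sen_model` and `processed_dataset_basic` fixtures as the test above, and that every gene in a cell sentence is matched to at least one token; the heatmap-style inspection at the end is purely illustrative:

import numpy as np

emb, attn = cell2sen_model.get_embeddings(processed_dataset_basic, output_attentions=True)
for sentence, sample_attn in zip(processed_dataset_basic["cell_sentence"], attn):
    genes = sentence.split()
    num_heads = sample_attn.shape[0]
    # one (num_heads, num_genes, num_genes) array per cell
    assert sample_attn.shape == (num_heads, len(genes), len(genes))
    head_mean = sample_attn.mean(axis=0)  # average over heads -> (num_genes, num_genes)
    src, tgt = np.unravel_index(head_mean.argmax(), head_mean.shape)
    print(f"strongest attention: {genes[src]} -> {genes[tgt]} ({head_mean[src, tgt]:.3f})")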
147 changes: 131 additions & 16 deletions helical/models/c2s/model.py
@@ -118,6 +118,87 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None:

LOGGER.info("Successfully loaded model")

@staticmethod
def _gene_ids_from_offsets(prompt, cell_sentence, offsets):
"""
Build a per-token gene id list from character offsets.

Tokens that fall within a gene's character span get the gene's
index; all other tokens (prompt text, special tokens, padding)
get None.

Parameters
----------
prompt : str
The full prompt string.
cell_sentence : str
Space-separated gene names embedded in the prompt.
offsets : list[tuple[int, int]]
Per-token (char_start, char_end) from ``return_offsets_mapping=True``.

Returns
-------
gene_ids : list[int | None]
Gene index for each token, or None.
"""
cs_start = prompt.find(cell_sentence)
genes = cell_sentence.split()

# character ranges for each gene (in prompt coordinates)
gene_ranges = []
pos = cs_start
for g in genes:
gs = prompt.index(g, pos)
gene_ranges.append((gs, gs + len(g)))
pos = gs + len(g)

gene_ids = [None] * len(offsets)
for tok_idx, (ts, te) in enumerate(offsets):
if ts == te:
continue
for gi, (gs, ge) in enumerate(gene_ranges):
if ts < ge and te > gs:
gene_ids[tok_idx] = gi
break
return gene_ids

@staticmethod
def _aggregate_token_to_word_attention(attn, word_ids):
"""
Aggregate a token-level attention matrix to word-level.

Parameters
----------
attn : np.ndarray
Token-level attention of shape (num_heads, seq_len, seq_len).
word_ids : list[int | None]
Word/gene id for each token position (None for non-gene tokens).

Returns
-------
word_attn : np.ndarray
Word-level attention of shape (num_heads, num_words, num_words).
"""
word_to_tokens = {}
for tok_idx, wid in enumerate(word_ids):
if wid is not None:
word_to_tokens.setdefault(wid, []).append(tok_idx)

num_words = len(word_to_tokens)
sorted_word_ids = sorted(word_to_tokens.keys())
num_heads = attn.shape[0]

word_attn = np.zeros((num_heads, num_words, num_words), dtype=attn.dtype)

for wi, src_wid in enumerate(sorted_word_ids):
src_tokens = word_to_tokens[src_wid]
for wj, tgt_wid in enumerate(sorted_word_ids):
tgt_tokens = word_to_tokens[tgt_wid]
block = attn[:, src_tokens, :][:, :, tgt_tokens] # (H, |src|, |tgt|)
word_attn[:, wi, wj] = block.sum(axis=2).mean(axis=1)

return word_attn

def process_data(
self,
adata: anndata.AnnData,
@@ -258,29 +339,43 @@ def process_data(
return dataset

def get_embeddings(
self,
dataset: Dataset,
self,
dataset: Dataset,
output_attentions: bool = False,
emb_layer: int = -1,
):
"""
Extract embeddings from cell sentences in a HuggingFace Dataset using the last hidden layer of Gemma.

Parameters:
-----------
dataset : Dataset
HuggingFace Dataset with 'cell_sentence' and 'organism' fields
HuggingFace Dataset with 'cell_sentence' and 'organism' fields

output_attentions : bool, optional
Whether to return attention maps along with the embeddings. If True, a gene-level attention map is returned for each sample in addition to the embeddings;
if False, only the embeddings are returned. **Note**: this significantly increases the model's memory usage, so enable it only if you need the attention maps.

emb_layer : int, optional
Which layer to extract attention from (default: -1, i.e. last layer).
Only used when output_attentions=True.

Returns:
--------
embeddings : np.ndarray
Embeddings of shape (num_sentences, hidden_size)
attn_list : list, optional
If output_attentions=True, a list of gene-level attention arrays,
one per sample, each of shape (num_heads, num_genes, num_genes).
"""

LOGGER.info("Extracting embeddings from dataset")

if output_attentions:
# SDPA and FlashAttention do not support returning attention maps;
# override to eager on the model config so all layers use it.
self.model.config._attn_implementation = "eager"

sentences_list = dataset['cell_sentence']
organisms_list = dataset['organism']

@@ -308,13 +403,18 @@ def get_embeddings(
return_tensors="pt",
padding=True,
truncation=False,
return_offsets_mapping=output_attentions,
# truncation=True,
# max_length=max_length
).to(self.device)
)
# pop offset_mapping so it is not passed to the model's forward call; it is only needed to align tokens to genes
if output_attentions:
batch_offsets = inputs.pop("offset_mapping")
inputs = inputs.to(self.device)

with torch.no_grad():
outputs = self.model(
**inputs,
**inputs,
output_hidden_states=True,
output_attentions=output_attentions
)
Expand Down Expand Up @@ -348,10 +448,21 @@ def get_embeddings(
if len(all_attentions) == 0:
num_layers = len(outputs.attentions)
all_attentions = [[] for _ in range(num_layers)]

# Append attention maps from each layer to the corresponding list

# Aggregate token-level attention to gene-level per sample
batch_size_actual = inputs['input_ids'].shape[0]
for layer_idx, attn in enumerate(outputs.attentions):
all_attentions[layer_idx].append(attn.float().cpu().numpy())
attn_np = attn.float().cpu().numpy() # (B, H, L, L)
word_attns = []
for b in range(batch_size_actual):
offsets_b = batch_offsets[b].tolist()
gene_ids = self._gene_ids_from_offsets(
prompts[b], batch_sentences[b], offsets_b
)
word_attns.append(
self._aggregate_token_to_word_attention(attn_np[b], gene_ids)
)
all_attentions[layer_idx].append(word_attns)
del outputs

all_embeddings.append(batch_embeddings.float().cpu().numpy())
@@ -360,13 +471,17 @@ def get_embeddings(
LOGGER.info("Successfully extracted embeddings")

if output_attentions:
# Concatenate attention maps per layer across batches
# Each element in stacked_attentions has shape (total_batch_size, num_heads, seq_length, seq_length)
stacked_attentions = tuple(
np.concatenate(all_attentions[layer_idx], axis=0)
# Restore the original attention implementation
self.model.config._attn_implementation = self.attn_implementation

# Flatten per-batch lists into a single list per layer
stacked_attentions = [
[arr for batch_list in all_attentions[layer_idx] for arr in batch_list]
for layer_idx in range(len(all_attentions))
)
return np.concatenate(all_embeddings, axis=0), stacked_attentions
]
# Return only the selected layer as a flat list (like Geneformer)
attn_list = stacked_attentions[emb_layer]
return np.concatenate(all_embeddings, axis=0), attn_list
else:
return np.concatenate(all_embeddings, axis=0)

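To make the token-to-gene aggregation concrete, here is a standalone toy walkthrough of the two new helpers. The functions are copied out of the diff as free functions so the sketch runs without loading the model, and the prompt, gene names, and character offsets are hypothetical stand-ins for what the tokenizer would produce with `return_offsets_mapping=True`:

import numpy as np

def gene_ids_from_offsets(prompt, cell_sentence, offsets):
    # Same logic as the _gene_ids_from_offsets static method added above.
    cs_start = prompt.find(cell_sentence)
    gene_ranges, pos = [], cs_start
    for g in cell_sentence.split():
        gs = prompt.index(g, pos)
        gene_ranges.append((gs, gs + len(g)))
        pos = gs + len(g)
    gene_ids = [None] * len(offsets)
    for tok_idx, (ts, te) in enumerate(offsets):
        if ts == te:
            continue  # special/padding tokens have empty spans
        for gi, (gs, ge) in enumerate(gene_ranges):
            if ts < ge and te > gs:  # character-range overlap
                gene_ids[tok_idx] = gi
                break
    return gene_ids

def aggregate_token_to_word_attention(attn, word_ids):
    # Same logic as the _aggregate_token_to_word_attention static method added above.
    word_to_tokens = {}
    for tok_idx, wid in enumerate(word_ids):
        if wid is not None:
            word_to_tokens.setdefault(wid, []).append(tok_idx)
    sorted_word_ids = sorted(word_to_tokens)
    word_attn = np.zeros((attn.shape[0], len(sorted_word_ids), len(sorted_word_ids)), dtype=attn.dtype)
    for wi, src in enumerate(sorted_word_ids):
        for wj, tgt in enumerate(sorted_word_ids):
            block = attn[:, word_to_tokens[src], :][:, :, word_to_tokens[tgt]]
            # sum over target tokens, mean over source tokens
            word_attn[:, wi, wj] = block.sum(axis=2).mean(axis=1)
    return word_attn

prompt = "List the genes: MALAT1 GAPDH ACTB"
cell_sentence = "MALAT1 GAPDH ACTB"
# Hypothetical (char_start, char_end) offsets for 7 tokens covering the prompt.
offsets = [(0, 4), (5, 8), (9, 14), (14, 15), (16, 22), (23, 28), (29, 33)]

gene_ids = gene_ids_from_offsets(prompt, cell_sentence, offsets)
print(gene_ids)  # [None, None, None, None, 0, 1, 2] -> only gene tokens get an index

token_attn = np.random.default_rng(0).random((2, len(offsets), len(offsets)))  # (num_heads, seq_len, seq_len)
gene_attn = aggregate_token_to_word_attention(token_attn, gene_ids)
print(gene_attn.shape)  # (2, 3, 3): one gene-by-gene matrix per head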
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.5.3"
version = "1.6.0"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]