2 changes: 1 addition & 1 deletion examples/notebooks/Tahoe-x1-Tutorial.ipynb
@@ -485,7 +485,7 @@
")\n",
"\n",
"print(f\"Cell embeddings shape: {cell_embeddings_attn.shape}\")\n",
"print(f\"Attention weights shape: {attention_weights.shape}\")\n",
"print(f\"Attention weights shape: {attention_weights[0].shape}\")\n",
"print(f\"\\nAttention weights dimensions: (n_cells, n_heads, seq_length, seq_length)\")"
]
},
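A minimal sketch (not part of the notebook) of how the updated per-sample attention output might be inspected, assuming `attention_weights` is the list returned by the cell above and each element has shape (n_heads, seq_length, seq_length):

```python
# Sketch only: iterate the list of per-sample attention maps instead of
# indexing one stacked array, since sequence lengths can differ per cell.
for i, attn in enumerate(attention_weights[:3]):
    n_heads, seq_len, _ = attn.shape
    print(f"cell {i}: {n_heads} heads, seq_length {seq_len}")
```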
2 changes: 1 addition & 1 deletion examples/run_models/configs/c2s_config.yaml
@@ -1,3 +1,3 @@
batch_size: 8
perturbation_column: "perturbation"
model_size: "2B"
model_size: "2B"
21 changes: 19 additions & 2 deletions examples/run_models/run_c2s.py
@@ -8,7 +8,7 @@
def run(cfg: DictConfig):

adata = ad.read_h5ad("./yolksac_human.h5ad")
n_cells = 10
n_cells = 1
n_genes = 200
adata = adata[:n_cells, :n_genes].copy()
perturbation_column = "perturbation"
@@ -18,8 +18,25 @@ def run(cfg: DictConfig):
c2s = Cell2Sen(configurer=config)

processed_dataset = c2s.process_data(adata)
embeddings = c2s.get_embeddings(processed_dataset)
embeddings, attentions, genes_names_attn = c2s.get_embeddings(processed_dataset, output_attentions=True)
perturbed_dataset, perturbed_cell_sentences = c2s.get_perturbations(processed_dataset)
# Print the first cell sentence and its words for comparison
first_sentence = processed_dataset['cell_sentence'][0]
words = first_sentence.split()
print(f"\nFirst cell sentence ({len(words)} words): {first_sentence}")

# Show how each gene gets tokenized into subtokens
print("\nGene -> subtokens:")
for gene in words:
token_ids = c2s.tokenizer.encode(gene, add_special_tokens=False)
subtokens = c2s.tokenizer.convert_ids_to_tokens(token_ids)
print(f" {gene:>20s} -> {subtokens}")
print(f" Total genes: {len(words)}, Total subtokens (genes only): {sum(len(c2s.tokenizer.encode(g, add_special_tokens=False)) for g in words)}")

# attentions is a tuple of lists: attentions[layer][sample] -> (num_heads, num_words, num_words)
print(f"\nNumber of layers: {len(attentions)}")
print(f"Number of samples: {len(attentions[0])}")
print(f"First sample, first layer shape: {attentions[0][0].shape}")

if __name__ == "__main__":
run()
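As a hypothetical follow-up to the example above (not part of the PR), the per-sample maps could be reduced to a word-by-word matrix; this assumes `attentions[-1][0]` is the last layer's map for the first cell with shape (num_heads, num_words, num_words), as the printed comment states, and that it lives on the CPU:

```python
import numpy as np

# Average the last layer's attention over heads for the first cell.
attn0 = np.asarray(attentions[-1][0])   # (num_heads, num_words, num_words)
mean_attn = attn0.mean(axis=0)          # (num_words, num_words)

# Rank positions by the attention they receive on average; the returned
# genes_names_attn presumably maps positions back to gene names.
received = mean_attn.mean(axis=0)
top_positions = np.argsort(received)[::-1][:5]
print("Most-attended word positions:", top_positions)
```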
3 changes: 2 additions & 1 deletion examples/run_models/run_scgpt.py
@@ -29,7 +29,8 @@ def run(cfg: DictConfig):
embeddings, attn_weights = scgpt.get_embeddings(data, output_attentions=True)

print(embeddings)
print(attn_weights.shape)
print(attn_weights[0].shape)
print(len(attn_weights))


if __name__ == "__main__":
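With the list-based return shown above, downstream code iterates per-sample maps rather than indexing one stacked array; a small sketch under that assumption:

```python
# Each element of attn_weights is one cell's attention map,
# shape (n_heads, seq_len, seq_len); seq_len may differ across cells.
for i, attn in enumerate(attn_weights):
    print(f"sample {i}: heads={attn.shape[0]}, seq_len={attn.shape[1]}")
```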
15 changes: 10 additions & 5 deletions helical/models/scgpt/model.py
@@ -91,6 +91,7 @@ def get_embeddings(
dataset: Dataset,
output_attentions: bool = False,
output_genes: bool = False,
attn_layer: int = -1,
) -> np.array:
"""Gets the gene embeddings

@@ -103,15 +104,18 @@ def get_embeddings(
If set to False, only the embeddings will be returned. **Note**: This will increase the memory usage of the model significantly, so use it only if you need the attention maps.
output_genes : bool, optional, default=False
Whether to output the genes corresponding to the embeddings. If set to True, the genes will be returned as a list of strings corresponding to the embeddings.
attn_layer : int, optional, default=-1
Which transformer layer's attention to return. Supports negative indexing (e.g. -1 for the last layer).
Only used when output_attentions is True.

Returns
-------
np.ndarray | List[pd.Series]
The embeddings produced by the model.
The return type depends on the `emb_mode` parameter in the configuration.
If `emb_mode` is set to "gene", the embeddings are returned as a list of pd.Series which contain a mapping of gene_name:embedding for each cell.
np.ndarray
If `output_attentions` is set to True, the attention maps will be returned as a numpy array of shape (n_layers, n_heads, n_cells, n_tokens, n_tokens).
list[np.ndarray]
If `output_attentions` is set to True, a list of per-sample attention maps is returned, each of shape (n_heads, seq_len, seq_len).
list, optional
If `output_genes` is set to True, the genes corresponding to the embeddings will be returned as a list of strings.
Each element in the list corresponds to the genes for each input in the dataset.
@@ -177,7 +181,8 @@ def get_embeddings(
),
output_attentions=output_attentions,
)
resulting_attn_maps.extend(attn_maps)
# Select the requested layer: (batch, n_heads, seq, seq)
resulting_attn_maps.extend(attn_maps[attn_layer].cpu().numpy())
else:
embeddings = self.model._encode(
input_gene_ids,
@@ -211,11 +216,11 @@
if output_attentions and output_genes:
return (
resulting_embeddings,
torch.stack(resulting_attn_maps).cpu().numpy(),
resulting_attn_maps,
input_genes,
)
elif output_attentions:
return resulting_embeddings, torch.stack(resulting_attn_maps).cpu().numpy()
return resulting_embeddings, resulting_attn_maps
elif output_genes:
return resulting_embeddings, input_genes
else:
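A hedged usage sketch of the new `attn_layer` argument, assuming an already-configured scGPT wrapper `scgpt` and a processed `dataset` as in examples/run_models/run_scgpt.py:

```python
# Request attention from the first transformer layer instead of the last.
embeddings, attn_maps = scgpt.get_embeddings(
    dataset, output_attentions=True, attn_layer=0
)
# One (n_heads, seq_len, seq_len) map per cell.
print(len(attn_maps), attn_maps[0].shape)
```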
51 changes: 16 additions & 35 deletions helical/models/tahoe/model.py
@@ -201,6 +201,7 @@ def get_embeddings(
dataloader: DataLoader,
return_gene_embeddings: bool = False,
output_attentions: bool = False,
attn_layer: int = -1,
) -> Union[np.ndarray, tuple]:
"""Gets the embeddings from the Tahoe model.

@@ -217,6 +218,9 @@
Note: This requires the model to be initialized with attn_impl='torch'.
The default Flash Attention (attn_impl='flash') does not support attention
weight extraction for efficiency reasons.
attn_layer : int, optional, default=-1
Which transformer layer's attention to return. Supports negative indexing
(e.g. -1 for the last layer). Only used when output_attentions is True.

Returns
-------
@@ -231,10 +235,8 @@
- cell_embeddings: numpy array of shape (n_cells, embedding_dim)
- gene_embeddings: list of pandas Series, one per cell. Each Series contains
gene embeddings indexed by Ensembl IDs for genes expressed in that cell.
- attentions: numpy array containing attention weights from the last transformer layer.
Shape: (n_batches, batch_size, n_heads, seq_length, seq_length).
Sequence lengths vary per batch based on the number of genes expressed.
Only the last transformer layer's attention is returned to conserve memory.
- attentions: list of per-sample numpy arrays, each of shape (n_heads, seq_length, seq_length).
Sequence lengths vary per sample based on the number of genes expressed.
"""
LOGGER.info("Extracting embeddings from Tahoe model...")

@@ -296,11 +298,9 @@
cell_embs.append(output["cell_emb"].to("cpu").to(dtype=torch.float32))

if output_attentions:
# Only keep last layer attention to save memory
# Shape: (batch, n_heads, seq_len, seq_len)
# Convert to float32 for numpy compatibility
last_layer_attn = output["attentions"][-1].cpu().to(torch.float32)
all_attentions.append(last_layer_attn)
# Select the requested layer: (batch, n_heads, seq_len, seq_len)
layer_attn = output["attentions"][attn_layer].cpu().to(torch.float32)
all_attentions.append(layer_attn)

if return_gene_embeddings:
# Get gene embeddings for this batch: shape (batch_size, seq_len, d_model)
@@ -332,47 +332,28 @@
)


# Prepare attention arrays if requested
# Prepare attention list if requested: one np.ndarray per sample
if output_attentions:
# Find max sequence length across all batches
max_seq_len = max(attn.shape[2] for attn in all_attentions)

# Pad all batches to max_seq_len
padded_attentions = []
attn_list = []
for attn in all_attentions:
batch_size, n_heads, seq_len, _ = attn.shape
if seq_len < max_seq_len:
# Pad with zeros to max_seq_len
pad_size = max_seq_len - seq_len
padded = torch.nn.functional.pad(
attn,
(0, pad_size, 0, pad_size), # pad last 2 dimensions (seq_len, seq_len)
mode='constant',
value=0
)
padded_attentions.append(padded)
else:
padded_attentions.append(attn)

# Stack along first dimension and convert to numpy
# Shape: (n_batches, batch_size, n_heads, max_seq_len, max_seq_len)
attention_array = torch.cat(padded_attentions, dim=0).numpy()
# attn shape: (batch, n_heads, seq_len, seq_len)
attn_list.extend(attn.numpy())

# Return based on requested outputs
log_msg = f"Finished extracting embeddings. Cell shape: {cell_array.shape}"
if return_gene_embeddings:
log_msg += f", Gene embeddings: {len(all_gene_embeddings)} cells"
if output_attentions:
log_msg += f", Attention shape: {attention_array.shape}"
log_msg += f", Attention maps: {len(attn_list)} samples"
LOGGER.info(log_msg)

# Return appropriate combination
if return_gene_embeddings and output_attentions:
return cell_array, all_gene_embeddings, attention_array
return cell_array, all_gene_embeddings, attn_list
elif return_gene_embeddings:
return cell_array, all_gene_embeddings
elif output_attentions:
return cell_array, attention_array
return cell_array, attn_list
else:
return cell_array

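Because the padded (n_batches, batch_size, n_heads, max_seq_len, max_seq_len) array is gone, consumers no longer need to strip zero padding; a sketch of per-cell aggregation under the new list return (the helper name is illustrative, not part of the PR):

```python
import numpy as np

def mean_received_attention(attn_list):
    """Average attention each token receives, computed per cell.

    attn_list: list of arrays, each (n_heads, seq_len, seq_len), with
    seq_len varying per cell. Returns one (seq_len,) vector per cell.
    """
    return [np.asarray(a).mean(axis=(0, 1)) for a in attn_list]
```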
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.7.0"
version = "1.8.0"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]