Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 101 additions & 16 deletions examples/notebooks/Cell2Sen-Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "66aa1ee7",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-11-26 15:35:49,900 - WARNING:py.warnings:/home/simon/miniconda3/envs/helical_dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"\n",
"2025-11-26 15:35:50,266 - INFO:datasets:PyTorch version 2.7.0 available.\n",
"2025-11-26 15:35:50,267 - INFO:datasets:Polars version 0.20.31 available.\n",
"2025-11-26 15:35:50,583 - INFO:helical.utils.downloader:Starting to download: 'https://huggingface.co/datasets/helical-ai/yolksac_human/resolve/main/data/17_04_24_YolkSacRaw_F158_WE_annots.h5ad?download=true'\n",
"yolksac_human.h5ad: 100%|██████████| 553M/553M [00:04<00:00, 116MB/s] \n"
]
}
],
"source": [
"from helical.utils.downloader import Downloader\n",
"from pathlib import Path\n",
Expand All @@ -38,18 +52,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "20493054",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10, 37318)\n",
"10\n"
]
}
],
"source": [
"import anndata as ad\n",
"\n",
"adata = ad.read_h5ad(\"./yolksac_human.h5ad\")\n",
"# We subset to 10 cells and 2000 genes\n",
"n_cells = 10\n",
"n_genes = 200\n",
"adata = adata[:n_cells, :n_genes].copy()\n",
"adata = adata[:n_cells].copy()\n",
"\n",
"# we can specify the perturbations for each cell in the anndata or later as well in get_pertubations\n",
"perturbation_column = \"perturbation\"\n",
Expand All @@ -73,7 +95,20 @@
"execution_count": null,
"id": "63d0c2ae",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-11-26 15:36:17,060 - WARNING:py.warnings:/home/simon/miniconda3/envs/helical_dev/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
" from pkg_resources import get_distribution, DistributionNotFound\n",
"\n",
"2025-11-26 15:36:18,575 - INFO:helical.models.c2s.model:Using SDPA for attention implementation - default for CPU\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00, 1.55s/it]\n",
"2025-11-26 15:36:25,397 - INFO:helical.models.c2s.model:Successfully loaded model\n"
]
}
],
"source": [
"from helical.models.c2s import Cell2Sen\n",
"from helical.models.c2s import Cell2SenConfig\n",
Expand All @@ -82,15 +117,29 @@
"# when calling the model class both the model and weights are downloaded - we can choose the model size (\"2B\" vs \"27B\" Gemma model)\n",
"# if you would like to use 4-bit quantization for reduced memory usage, set use_quantization=True in the config\n",
"# on GPU devices, you can also use flash attention 2 by setting use_flash_attn=True in the config\n",
"# provide max_genes to only select the top genes in the ranked list\n",
"# provide max_genes to only select the top genes in the ranked list and save computation time\n",
"# See the config file for more details\n",
"\n",
"\n",
"# You can provide a custom prompt to the model, depending on your specific task. Below you see an example of how you can structure \n",
"# such a prompt (this is also the default prompt we use if you do not pass anything). Keep in mind to test your prompt and evaluate\n",
"# results before you use them, as the model was only trained on limited prompt type and may not react to yours in the way you expect.\n",
"custom_prompt = \"\"\"\n",
" You are given a list of genes in descending order of expression levels in a {organism} cell. \\n\n",
" Genes: {cell_sentence} \\n\n",
" Using this information, predict the cell type. Answer: \n",
" \"\"\"\n",
"\n",
"config = Cell2SenConfig(\n",
" batch_size=8, \n",
" perturbation_column=perturbation_column, \n",
" model_size=\"2B\", \n",
" device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
" use_quantization=True)\n",
" use_quantization=True,\n",
" max_genes=50,\n",
" aggregation_type=\"mean_pool\",\n",
" embedding_prompt_template=custom_prompt)\n",
"\n",
"cell2sen_model = Cell2Sen(configurer=config)"
]
},
Expand All @@ -104,13 +153,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "351e1e80",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-11-26 15:38:31,931 - INFO:helical.models.c2s.model:Processing data\n",
"Processing cells: 100%|██████████| 10/10 [00:00<00:00, 3826.22it/s]\n",
"2025-11-26 15:38:31,951 - INFO:helical.models.c2s.model:Successfully processed data\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'cell_sentence': 'MALAT1 MEG3 TTR AFP APOA1 MT-CO1 ALB APOE MT-CO2 KLF6 APOC3 VTN APOA2 NEAT1 TF MTRNR2L12 SERPINA1 MT-CYB MEG8 MT-ATP6 MT1G MT-CO3 APOB MT-ND6 SAT1 APOC2 MT-ND4 RPS15 CST3 APOM RPS18 H3F3B MT1H GPC3 JUND FADS1 TIMP3 APOC1 MT-ND4L FLVCR1 EEF1A1 RPL41 RPL37A MT-ND1 JUN H19 PHACTR2 RPL13 UBC REEP6', 'fit_parameters': None, 'organism': 'unknown', 'perturbations': 'IFNg'}\n",
"#Genes in sample: 50\n"
]
}
],
"source": [
"processes_dataset = cell2sen_model.process_data(adata)\n",
"print(processes_dataset[0])"
"print(processes_dataset[0])\n",
"print(f'#Genes in sample: {len(processes_dataset[0][\"cell_sentence\"].split(\" \"))}')"
]
},
{
Expand All @@ -125,10 +193,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "2d5ac9f3",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-11-26 15:38:39,235 - INFO:helical.models.c2s.model:Extracting embeddings from dataset\n",
"Processing embeddings: 100%|██████████| 10/10 [00:00<00:00, 21.55it/s]\n",
"2025-11-26 15:38:39,715 - INFO:helical.models.c2s.model:Successfully extracted embeddings\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10, 2304)\n"
]
}
],
"source": [
"# set output_attentions=True to get the attention maps - this will return attentions for each layer in the model per head\n",
"\n",
Expand Down Expand Up @@ -188,7 +273,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "helical",
"display_name": "helical_dev",
"language": "python",
"name": "python3"
},
Expand All @@ -202,7 +287,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.14"
"version": "3.11.13"
}
},
"nbformat": 4,
Expand Down
18 changes: 17 additions & 1 deletion helical/models/c2s/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ class Cell2SenConfig:
Maximum number of genes to use for the model. Default is 200.
If None, all genes will be used.
If a number is provided, the genes will be sorted by expression level and the top max_genes will be used.

aggregation_type: Literal["mean_pool", "last_token"] = "mean_pool"
How to aggregate final-layer hidden states into a single embedding. Defaults to "mean_pool".
"mean_pool": Computes the mean of all non-padding token embeddings in the last layer.
"last_token": Uses only the embedding of the final non-padding token (i.e., the position where the model would predict the next token).

embedding_prompt_template: str = None
Optional custom embedding prompt template used to query the model.
If None, a default built-in prompt template is used.
Example: 'You are given a list of genes in descending order of expression levels in a {organism} cell. \n
Genes: {cell_sentence} \n
Using this information, describe the function of the cell in a few words. Answer:'

device: Literal["cpu", "cuda"] = "cpu"
Device to use for the model. Default is "cpu".
Expand All @@ -73,7 +85,9 @@ def __init__(
organism: str = None,
perturbation_column: str = None,
max_new_tokens: int = 200,
max_genes: int = 200,
max_genes: int = None,
aggregation_type: Literal["mean_pool", "last_token"] = "mean_pool",
embedding_prompt_template: str = None,
return_fit: bool = False,
dtype: str = "bfloat16",
model_size: str = "2B",
Expand Down Expand Up @@ -137,5 +151,7 @@ def __init__(
"model_size": model_size,
"use_flash_attn": use_flash_attn,
"max_genes": max_genes,
"aggregation_type": aggregation_type,
"embedding_prompt_template": embedding_prompt_template,
"device": device,
}
54 changes: 39 additions & 15 deletions helical/models/c2s/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None:
# downloader.download_via_name(file)

self.device = self.config["device"]
if self.device == "cuda" and self.config["use_flash_attn"]:
if "cuda" in self.device and self.config["use_flash_attn"]:
LOGGER.info("Using flash attention 2 for attention implementation")
self.attn_implementation = "flash_attention_2"
else:
Expand Down Expand Up @@ -100,7 +100,8 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None:
cache_dir=self.config["model_path"],
quantization_config=self.bnb_config,
attn_implementation=self.attn_implementation,
).to(self.device)
device_map=self.device
)

self.tokenizer = AutoTokenizer.from_pretrained(self.config["hf_model_path"], cache_dir=self.config["model_path"])
self.model.eval()
Expand All @@ -111,6 +112,8 @@ def __init__(self, configurer: Cell2SenConfig = None) -> None:
self.perturbation_column = self.config['perturbation_column']
self.return_fit = self.config['return_fit']
self.max_genes = self.config['max_genes']
self.aggregation_type = self.config["aggregation_type"]
self.embedding_prompt_template = self.config["embedding_prompt_template"]

LOGGER.info("Successfully loaded model")

Expand Down Expand Up @@ -256,7 +259,7 @@ def process_data(
def get_embeddings(
self,
dataset: Dataset,
output_attentions: bool = False
output_attentions: bool = False,
):
"""
Extract embeddings from cell sentences in a HuggingFace Dataset using the last hidden layer of Gemma.
Expand Down Expand Up @@ -288,10 +291,16 @@ def get_embeddings(
batch_sentences = sentences_list[i:i + self.batch_size]
batch_organisms = organisms_list[i:i + self.batch_size]

prompts = [
EMBEDDING_PROMPT.format(organism=org, cell_sentence=cs)
for org, cs in zip(batch_organisms, batch_sentences)
]
if self.embedding_prompt_template is None:
prompts = [
EMBEDDING_PROMPT.format(organism=org, cell_sentence=cs)
for org, cs in zip(batch_organisms, batch_sentences)
]
else:
prompts = [
self.embedding_prompt_template.format(organism=org, cell_sentence=cs)
for org, cs in zip(batch_organisms, batch_sentences)
]

inputs = self.tokenizer(
prompts,
Expand All @@ -308,14 +317,29 @@ def get_embeddings(
output_hidden_states=True,
output_attentions=output_attentions
)
last_hidden = outputs.hidden_states[-1] # Shape: (batch_size, seq_len, hidden_size)
attention_mask = inputs['attention_mask'].float() # Shape: (batch_size, seq_len, 1)

# Sum embeddings over sequence length, masking out padding tokens
sum_embeddings = torch.sum(last_hidden * attention_mask.unsqueeze(-1), dim=1)
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask

last_hidden = outputs.hidden_states[-1] # (B, L, H)
attention_mask = inputs['attention_mask'].float() # (B, L)

if self.aggregation_type == 'mean_pool':
# mean pooling over non-padding tokens
masked_hidden = last_hidden * attention_mask.unsqueeze(-1) # (B, L, H)
sum_embeddings = masked_hidden.sum(dim=1) # (B, H)
sum_mask = attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
batch_embeddings = sum_embeddings / sum_mask # (B, H)

elif self.aggregation_type == 'last_token':
# index of last non-padding token
last_idx = (attention_mask.sum(dim=1) - 1).long() # (B,)

# gather token representations
batch_embeddings = last_hidden[
torch.arange(last_hidden.size(0), device=last_hidden.device),
last_idx
] # (B, H)

else:
raise ValueError("Invalid aggregation type. Use 'mean_pool' or 'last_token'.")

if output_attentions:
# outputs.attentions is a tuple of tensors, one per layer
# Each tensor has shape (batch_size, num_heads, seq_length, seq_length)
Expand Down
4 changes: 4 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ repo_name: helicalAI/helical
copyright: <a href="https://www.helical-ai.com/">Helical Team</a> <br> Copyright &copy; 2024
nav:
- Single-Cell Models:
- Cell2Sen-Scale:
- Model Card: ./model_cards/c2s.md
- Config: ./configs/c2s_config.md
- Model: ./models/c2s.md
- Geneformer:
- Model Card: ./model_cards/geneformer.md
- Config: ./configs/geneformer_config.md
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.4.19"
version = "1.4.20"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
Loading