Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions ci/tests/test_nicheformer/test_nicheformer_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
import types

import pytest
import numpy as np
import torch
import torch.nn as nn
import anndata as ad
from anndata import AnnData
from scipy.sparse import csr_matrix
from datasets import Dataset

from helical.models.nicheformer import Nicheformer, NicheformerConfig

_STUB_DIM = 8
_STUB_NHEADS = 2
_STUB_SEQ_LEN = 10
_STUB_VOCAB_SIZE = 40


@pytest.fixture
def stub_bert():
    """Tiny drop-in replacement for the Nicheformer BERT backbone.

    Exposes exactly the attributes the attention-extraction path reads:
    ``config`` (with ``nlayers`` and ``learnable_pe``), ``embeddings``,
    ``positional_embedding``, ``pos``, ``dropout`` and ``encoder.layers``.
    """

    class _TinyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            single_layer = nn.TransformerEncoderLayer(
                d_model=_STUB_DIM,
                nhead=_STUB_NHEADS,
                batch_first=True,
                norm_first=False,
            )
            self.layers = nn.ModuleList([single_layer])

    class _TinyBert(nn.Module):
        def __init__(self):
            super().__init__()
            # One post-norm layer with learnable positional embeddings.
            self.config = types.SimpleNamespace(nlayers=1, learnable_pe=True)
            self.embeddings = nn.Embedding(_STUB_VOCAB_SIZE, _STUB_DIM)
            self.positional_embedding = nn.Embedding(_STUB_SEQ_LEN, _STUB_DIM)
            self.pos = torch.arange(0, _STUB_SEQ_LEN)
            self.dropout = nn.Dropout(0.0)
            self.encoder = _TinyEncoder()

    return _TinyBert()


@pytest.fixture
def _mocks(mocker):
Expand All @@ -24,7 +63,11 @@ def _tokenize(adata, **kwargs):
}

mock_tokenizer.side_effect = _tokenize
mock_tokenizer.get_vocab.return_value = {"GENE1": 30, "GENE2": 31, "GENE3": 32}
mock_tokenizer.get_vocab.return_value = {
"ENSG00000000001": 30,
"ENSG00000000002": 31,
"ENSG00000000003": 32,
}
mocker.patch(
"helical.models.nicheformer.model.AutoTokenizer.from_pretrained",
return_value=mock_tokenizer,
Expand Down Expand Up @@ -54,7 +97,7 @@ def nicheformer(_mocks):
def mock_adata():
    """Build a 3x3 integer-count AnnData with Ensembl-ID var names."""
    counts = np.arange(1, 10, dtype=np.int32).reshape(3, 3)
    result = AnnData(X=counts)
    result.obs_names = [f"cell{i}" for i in range(1, 4)]
    # ENSG + 11 zero-padded digits, matching the mocked tokenizer vocab.
    result.var_names = [f"ENSG{i:011d}" for i in range(1, 4)]
    return result


Expand Down Expand Up @@ -106,7 +149,7 @@ def test_float_counts_raises_value_error(self, nicheformer):
adata.X = adata.X.astype(float)
adata.X[0, 0] = 0.5
with pytest.raises(ValueError):
nicheformer.process_data(adata, gene_names="index")
nicheformer.process_data(adata)

def test_no_vocab_genes_raises_value_error(self, nicheformer, mock_adata, _mocks):
mock_tokenizer, _ = _mocks
Expand Down Expand Up @@ -155,6 +198,37 @@ def test_with_context_forwarded_to_model(self, nicheformer, mock_adata, _mocks):
nicheformer.get_embeddings(dataset)
assert mock_model.get_embeddings.call_args.kwargs["with_context"] is True

def test_attention_shape_invariant_to_masking(self, nicheformer, stub_bert):
    """Attention output shape must not depend on how many tokens are masked."""
    nicheformer.model.bert = stub_bert
    batch = 2
    tokens = torch.zeros((batch, _STUB_SEQ_LEN), dtype=torch.long)

    def attn_for(mask):
        return nicheformer._extract_attention_weights(tokens, mask, layer=-1)

    full_mask = torch.ones((batch, _STUB_SEQ_LEN), dtype=torch.bool)
    half_mask = full_mask.clone()
    half_mask[:, _STUB_SEQ_LEN // 2 :] = False

    want = (batch, _STUB_NHEADS, _STUB_SEQ_LEN, _STUB_SEQ_LEN)
    assert attn_for(full_mask).shape == want
    assert attn_for(half_mask).shape == want

def test_output_attentions_shape(self, nicheformer, mock_adata, mocker):
    """get_embeddings(output_attentions=True) also returns per-head attention maps."""
    shape = (mock_adata.n_obs, 16, 1500, 1500)
    # Stub out the heavy attention extraction; only the plumbing is under test.
    mocker.patch.object(
        nicheformer,
        "_extract_attention_weights",
        return_value=torch.zeros(*shape),
    )
    processed = nicheformer.process_data(mock_adata)
    _, attn = nicheformer.get_embeddings(processed, output_attentions=True)
    assert attn.shape == shape


class TestNicheformerTechnologyMean:
def test_none_does_not_call_load(self, _mocks, mocker):
Expand Down
13 changes: 12 additions & 1 deletion examples/run_models/run_nicheformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,18 @@ def run(cfg: DictConfig):
dataset = nicheformer.process_data(ann_data[:10])

cell_embeddings = nicheformer.get_embeddings(dataset)
print(f"Cell embeddings shape: {cell_embeddings.shape}")
print(f"Cell embeddings shape (Ensembl IDs): {cell_embeddings.shape}")

# yolksac uses gene symbols — exercises the symbol-to-Ensembl mapping path.
ann_data_yolksac = ad.read_h5ad("./yolksac_human.h5ad")
ann_data_yolksac.obs["modality"] = "dissociated"
ann_data_yolksac.obs["specie"] = "human"
ann_data_yolksac.obs["assay"] = "10x 3' v3"

dataset_yolksac = nicheformer.process_data(ann_data_yolksac[:10])

cell_embeddings_yolksac = nicheformer.get_embeddings(dataset_yolksac)
print(f"Cell embeddings shape (gene symbols): {cell_embeddings_yolksac.shape}")


if __name__ == "__main__":
Expand Down
111 changes: 107 additions & 4 deletions helical/models/nicheformer/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from helical.models.base_models import HelicalRNAModel
from helical.models.nicheformer.nicheformer_config import NicheformerConfig
from helical.utils.downloader import Downloader
from helical.utils.mapping import map_gene_symbols_to_ensembl_ids
from anndata import AnnData
from datasets import Dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer
Expand Down Expand Up @@ -109,7 +110,11 @@ def process_data(
are used to prepend context tokens when present.
gene_names : str, optional, default="index"
The column in ``adata.var`` that contains gene names. If set to
``"index"``, the index of ``adata.var`` is used.
``"index"``, the index of ``adata.var`` is used. If set to
``"ensembl_id"``, no symbol-to-Ensembl mapping is performed and
``adata.var_names`` must already be Ensembl IDs. Otherwise the
symbols in the given column are mapped to Ensembl IDs via the
static BioMart table.
use_raw_counts : bool, optional, default=True
Whether to validate that the expression matrix contains raw integer
counts.
Expand All @@ -121,7 +126,24 @@ def process_data(
columns, ready for :meth:`get_embeddings`.
"""
LOGGER.info("Processing data for Nicheformer.")
self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)
# When gene_names="ensembl_id" the IDs live in var_names (the index),
# not in a separate column, so validate against "index".
self.ensure_rna_data_validity(
adata, "index" if gene_names == "ensembl_id" else gene_names, use_raw_counts
)

if gene_names != "ensembl_id":
col = adata.var[gene_names]
if not col.str.startswith("ENS").all():
adata = map_gene_symbols_to_ensembl_ids(
adata, gene_names if gene_names != "index" else None
)
if adata.var["ensembl_id"].isnull().all():
message = "All gene symbols could not be mapped to Ensembl IDs. Please check the input data."
LOGGER.error(message)
raise ValueError(message)
adata = adata[:, adata.var["ensembl_id"].notnull()]
adata.var_names = adata.var["ensembl_id"].values

ref_genes = {k for k in self.tokenizer.get_vocab() if not k.startswith("[")}
_original_gene_count = len(adata.var_names)
Expand All @@ -147,7 +169,11 @@ def process_data(
LOGGER.info("Successfully processed data for Nicheformer.")
return dataset

def get_embeddings(self, dataset: Dataset) -> np.ndarray:
def get_embeddings(
self,
dataset: Dataset,
output_attentions: bool = False,
) -> np.ndarray:
"""Extracts cell embeddings from a processed dataset using Nicheformer.

Embeddings are obtained by mean-pooling over the sequence dimension at
Expand All @@ -157,11 +183,18 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray:
----------
dataset : Dataset
The processed dataset returned by :meth:`process_data`.
output_attentions : bool, optional, default=False
Whether to return per-head attention weights from the target
transformer layer. When ``True`` a second array is returned with
shape ``(n_cells, n_heads, seq_length, seq_length)``.

Returns
-------
np.ndarray
Cell embeddings of shape ``(n_cells, 512)``.
np.ndarray, optional
Attention weights of shape ``(n_cells, n_heads, seq_length,
seq_length)``, only returned when ``output_attentions=True``.
"""
LOGGER.info("Started getting embeddings for Nicheformer.")

Expand All @@ -171,6 +204,7 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray:

dataset.set_format(type="torch")
all_embeddings = []
all_attentions = [] if output_attentions else None

for i in range(0, len(dataset), batch_size):
batch = dataset[i : i + batch_size]
Expand All @@ -185,7 +219,76 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray:
with_context=with_context,
)

if output_attentions:
attn = self._extract_attention_weights(
input_ids, attention_mask, layer
)
all_attentions.append(attn.cpu().numpy())

all_embeddings.append(embeddings.cpu().numpy())

LOGGER.info("Finished getting embeddings for Nicheformer.")
return np.concatenate(all_embeddings, axis=0)
embeddings_out = np.concatenate(all_embeddings, axis=0)
if output_attentions:
return embeddings_out, np.concatenate(all_attentions, axis=0)
return embeddings_out

def _extract_attention_weights(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    layer: int,
) -> torch.Tensor:
    """Return per-head attention weights from the target encoder layer.

    Runs the embedding preparation and all encoder layers *before* the
    target layer, then re-applies the target layer's self-attention with
    ``need_weights=True`` to recover the per-head weight matrix (the
    encoder layer's own forward discards it).

    Parameters
    ----------
    input_ids : torch.Tensor
        Token IDs of shape ``(batch, seq_len)``.
    attention_mask : torch.Tensor
        Boolean mask of shape ``(batch, seq_len)`` — ``True`` for real
        tokens.
    layer : int
        Target layer index (negative values count from the last layer).

    Returns
    -------
    torch.Tensor
        Attention weights of shape ``(batch, n_heads, seq_len, seq_len)``.

    Raises
    ------
    ValueError
        If ``layer`` is outside ``[-nlayers, nlayers - 1]``. (Previously
        an out-of-range value surfaced as an ``IndexError`` or an unbound
        local variable.)
    """
    bert = self.model.bert
    n_layers = bert.config.nlayers
    layer_idx = n_layers + layer if layer < 0 else layer
    if not 0 <= layer_idx < n_layers:
        raise ValueError(
            f"layer={layer} is out of range for a model with {n_layers} layers."
        )

    token_embedding = bert.embeddings(input_ids)
    if bert.config.learnable_pe:
        pos_embedding = bert.positional_embedding(
            bert.pos.to(token_embedding.device)
        )
        x = bert.dropout(token_embedding + pos_embedding)
    else:
        # Non-learnable PE module transforms the token embeddings directly.
        x = bert.positional_embedding(token_embedding)

    # MultiheadAttention expects True where positions should be IGNORED,
    # the inverse of the incoming "True = real token" convention.
    padding_mask = ~attention_mask.bool()

    # Run only the layers feeding the target layer; the target layer's own
    # output is never used, so (unlike before) we don't waste a forward
    # pass computing it.
    for i in range(layer_idx):
        x = bert.encoder.layers[i](
            x, src_key_padding_mask=padding_mask, is_causal=False
        )

    enc_layer = bert.encoder.layers[layer_idx]
    # Pre-norm layers normalize before attention; post-norm use the raw input.
    query = enc_layer.norm1(x) if enc_layer.norm_first else x
    _, attn_weights = enc_layer.self_attn(
        query,
        query,
        query,
        key_padding_mask=padding_mask,
        need_weights=True,
        average_attn_weights=False,
    )
    return attn_weights
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "1.10.0"
version = "1.10.1"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ pytest==8.2.0
pytest-cov==5.0.0
pytest-mock==3.14.0
nbmake==1.5.4
black==26.1.0
black==26.3.1

Loading