From f3060daa4b74bd769322712e73948d2dc7b8ffee Mon Sep 17 00:00:00 2001 From: "Dmitry Ivanov @ helical-ai.com" Date: Thu, 19 Mar 2026 15:08:21 +0100 Subject: [PATCH] Nicheformer: ENSEMBL and attn maps (#354) We are adding: 1. Conversion of gene IDs to ENSEMBL - #352 2. Facilities for attention map extraction - #353 We are bumping a dependency version: black 26.1.0 -> 26.3.1 Bump version --- .../test_nicheformer_model.py | 80 ++++++++++++- examples/run_models/run_nicheformer.py | 13 +- helical/models/nicheformer/model.py | 111 +++++++++++++++++- pyproject.toml | 2 +- requirements-dev.txt | 2 +- 5 files changed, 198 insertions(+), 10 deletions(-) diff --git a/ci/tests/test_nicheformer/test_nicheformer_model.py b/ci/tests/test_nicheformer/test_nicheformer_model.py index e2b35b9b..b3d49b51 100644 --- a/ci/tests/test_nicheformer/test_nicheformer_model.py +++ b/ci/tests/test_nicheformer/test_nicheformer_model.py @@ -1,6 +1,9 @@ +import types + import pytest import numpy as np import torch +import torch.nn as nn import anndata as ad from anndata import AnnData from scipy.sparse import csr_matrix @@ -8,6 +11,42 @@ from helical.models.nicheformer import Nicheformer, NicheformerConfig +_STUB_DIM = 8 +_STUB_NHEADS = 2 +_STUB_SEQ_LEN = 10 +_STUB_VOCAB_SIZE = 40 + + +@pytest.fixture +def stub_bert(): + """Minimal real NicheformerModel-compatible module for attention tests.""" + + class _StubEncoder(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.TransformerEncoderLayer( + d_model=_STUB_DIM, + nhead=_STUB_NHEADS, + batch_first=True, + norm_first=False, + ) + ] + ) + + class _StubBert(nn.Module): + def __init__(self): + super().__init__() + self.config = types.SimpleNamespace(nlayers=1, learnable_pe=True) + self.embeddings = nn.Embedding(_STUB_VOCAB_SIZE, _STUB_DIM) + self.positional_embedding = nn.Embedding(_STUB_SEQ_LEN, _STUB_DIM) + self.pos = torch.arange(0, _STUB_SEQ_LEN) + self.dropout = nn.Dropout(0.0) + self.encoder = _StubEncoder() + + return _StubBert() + @pytest.fixture def _mocks(mocker): @@ -24,7 +63,11 @@ def _tokenize(adata, **kwargs): } mock_tokenizer.side_effect = _tokenize - mock_tokenizer.get_vocab.return_value = {"GENE1": 30, "GENE2": 31, "GENE3": 32} + mock_tokenizer.get_vocab.return_value = { + "ENSG00000000001": 30, + "ENSG00000000002": 31, + "ENSG00000000003": 32, + } mocker.patch( "helical.models.nicheformer.model.AutoTokenizer.from_pretrained", return_value=mock_tokenizer, @@ -54,7 +97,7 @@ def nicheformer(_mocks): def mock_adata(): adata = AnnData(X=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)) adata.obs_names = ["cell1", "cell2", "cell3"] - adata.var_names = ["GENE1", "GENE2", "GENE3"] + adata.var_names = ["ENSG00000000001", "ENSG00000000002", "ENSG00000000003"] return adata @@ -106,7 +149,7 @@ def test_float_counts_raises_value_error(self, nicheformer): adata.X = adata.X.astype(float) adata.X[0, 0] = 0.5 with pytest.raises(ValueError): - nicheformer.process_data(adata, gene_names="index") + nicheformer.process_data(adata) def test_no_vocab_genes_raises_value_error(self, nicheformer, mock_adata, _mocks): mock_tokenizer, _ = _mocks @@ -155,6 +198,37 @@ def test_with_context_forwarded_to_model(self, nicheformer, mock_adata, _mocks): nicheformer.get_embeddings(dataset) assert mock_model.get_embeddings.call_args.kwargs["with_context"] is True + def test_attention_shape_invariant_to_masking(self, nicheformer, stub_bert): + nicheformer.model.bert = stub_bert + n_obs = 2 + input_ids = torch.zeros((n_obs, _STUB_SEQ_LEN), dtype=torch.long) + + mask_full = torch.ones((n_obs, _STUB_SEQ_LEN), dtype=torch.bool) + attn_full = nicheformer._extract_attention_weights( + input_ids, mask_full, layer=-1 + ) + + mask_partial = mask_full.clone() + mask_partial[:, _STUB_SEQ_LEN // 2 :] = False + attn_partial = nicheformer._extract_attention_weights( + input_ids, mask_partial, layer=-1 + ) + + expected = (n_obs, _STUB_NHEADS, _STUB_SEQ_LEN, _STUB_SEQ_LEN) + assert attn_full.shape == expected + assert attn_partial.shape == expected + + def test_output_attentions_shape(self, nicheformer, mock_adata, mocker): + n_obs, n_heads, seq_len = mock_adata.n_obs, 16, 1500 + mocker.patch.object( + nicheformer, + "_extract_attention_weights", + return_value=torch.zeros(n_obs, n_heads, seq_len, seq_len), + ) + dataset = nicheformer.process_data(mock_adata) + _, attentions = nicheformer.get_embeddings(dataset, output_attentions=True) + assert attentions.shape == (n_obs, n_heads, seq_len, seq_len) + class TestNicheformerTechnologyMean: def test_none_does_not_call_load(self, _mocks, mocker): diff --git a/examples/run_models/run_nicheformer.py b/examples/run_models/run_nicheformer.py index 40223d0c..a7da064d 100644 --- a/examples/run_models/run_nicheformer.py +++ b/examples/run_models/run_nicheformer.py @@ -19,7 +19,18 @@ def run(cfg: DictConfig): dataset = nicheformer.process_data(ann_data[:10]) cell_embeddings = nicheformer.get_embeddings(dataset) - print(f"Cell embeddings shape: {cell_embeddings.shape}") + print(f"Cell embeddings shape (Ensembl IDs): {cell_embeddings.shape}") + + # yolksac uses gene symbols — exercises the symbol-to-Ensembl mapping path. + ann_data_yolksac = ad.read_h5ad("./yolksac_human.h5ad") + ann_data_yolksac.obs["modality"] = "dissociated" + ann_data_yolksac.obs["specie"] = "human" + ann_data_yolksac.obs["assay"] = "10x 3' v3" + + dataset_yolksac = nicheformer.process_data(ann_data_yolksac[:10]) + + cell_embeddings_yolksac = nicheformer.get_embeddings(dataset_yolksac) + print(f"Cell embeddings shape (gene symbols): {cell_embeddings_yolksac.shape}") if __name__ == "__main__": diff --git a/helical/models/nicheformer/model.py b/helical/models/nicheformer/model.py index 328c9113..e848a6fc 100644 --- a/helical/models/nicheformer/model.py +++ b/helical/models/nicheformer/model.py @@ -1,6 +1,7 @@ from helical.models.base_models import HelicalRNAModel from helical.models.nicheformer.nicheformer_config import NicheformerConfig from helical.utils.downloader import Downloader +from helical.utils.mapping import map_gene_symbols_to_ensembl_ids from anndata import AnnData from datasets import Dataset from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -109,7 +110,11 @@ def process_data( are used to prepend context tokens when present. gene_names : str, optional, default="index" The column in ``adata.var`` that contains gene names. If set to - ``"index"``, the index of ``adata.var`` is used. + ``"index"``, the index of ``adata.var`` is used. If set to + ``"ensembl_id"``, no symbol-to-Ensembl mapping is performed and + ``adata.var_names`` must already be Ensembl IDs. Otherwise the + symbols in the given column are mapped to Ensembl IDs via the + static BioMart table. use_raw_counts : bool, optional, default=True Whether to validate that the expression matrix contains raw integer counts. @@ -121,7 +126,24 @@ def process_data( columns, ready for :meth:`get_embeddings`. """ LOGGER.info("Processing data for Nicheformer.") - self.ensure_rna_data_validity(adata, gene_names, use_raw_counts) + # When gene_names="ensembl_id" the IDs live in var_names (the index), + # not in a separate column, so validate against "index". + self.ensure_rna_data_validity( + adata, "index" if gene_names == "ensembl_id" else gene_names, use_raw_counts + ) + + if gene_names != "ensembl_id": + col = adata.var[gene_names] + if not col.str.startswith("ENS").all(): + adata = map_gene_symbols_to_ensembl_ids( + adata, gene_names if gene_names != "index" else None + ) + if adata.var["ensembl_id"].isnull().all(): + message = "All gene symbols could not be mapped to Ensembl IDs. Please check the input data." + LOGGER.error(message) + raise ValueError(message) + adata = adata[:, adata.var["ensembl_id"].notnull()] + adata.var_names = adata.var["ensembl_id"].values ref_genes = {k for k in self.tokenizer.get_vocab() if not k.startswith("[")} _original_gene_count = len(adata.var_names) @@ -147,7 +169,11 @@ def process_data( LOGGER.info("Successfully processed data for Nicheformer.") return dataset - def get_embeddings(self, dataset: Dataset) -> np.ndarray: + def get_embeddings( + self, + dataset: Dataset, + output_attentions: bool = False, + ) -> np.ndarray: """Extracts cell embeddings from a processed dataset using Nicheformer. Embeddings are obtained by mean-pooling over the sequence dimension at @@ -157,11 +183,18 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray: ---------- dataset : Dataset The processed dataset returned by :meth:`process_data`. + output_attentions : bool, optional, default=False + Whether to return per-head attention weights from the target + transformer layer. When ``True`` a second array is returned with + shape ``(n_cells, n_heads, seq_length, seq_length)``. Returns ------- np.ndarray Cell embeddings of shape ``(n_cells, 512)``. + np.ndarray, optional + Attention weights of shape ``(n_cells, n_heads, seq_length, + seq_length)``, only returned when ``output_attentions=True``. """ LOGGER.info("Started getting embeddings for Nicheformer.") @@ -171,6 +204,7 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray: dataset.set_format(type="torch") all_embeddings = [] + all_attentions = [] if output_attentions else None for i in range(0, len(dataset), batch_size): batch = dataset[i : i + batch_size] @@ -185,7 +219,76 @@ def get_embeddings(self, dataset: Dataset) -> np.ndarray: with_context=with_context, ) + if output_attentions: + attn = self._extract_attention_weights( + input_ids, attention_mask, layer + ) + all_attentions.append(attn.cpu().numpy()) + all_embeddings.append(embeddings.cpu().numpy()) LOGGER.info("Finished getting embeddings for Nicheformer.") - return np.concatenate(all_embeddings, axis=0) + embeddings_out = np.concatenate(all_embeddings, axis=0) + if output_attentions: + return embeddings_out, np.concatenate(all_attentions, axis=0) + return embeddings_out + + def _extract_attention_weights( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + layer: int, + ) -> torch.Tensor: + """Return per-head attention weights from the target encoder layer. + + Runs the embedding preparation and all encoder layers up to and + including ``layer``, extracting the self-attention weights at that + layer via ``MultiheadAttention`` with ``need_weights=True``. + + Parameters + ---------- + input_ids : torch.Tensor + Token IDs of shape ``(batch, seq_len)``. + attention_mask : torch.Tensor + Boolean mask of shape ``(batch, seq_len)`` — ``True`` for real + tokens. + layer : int + Target layer index (negative values count from the last layer). + + Returns + ------- + torch.Tensor + Attention weights of shape ``(batch, n_heads, seq_len, seq_len)``. + """ + bert = self.model.bert + layer_idx = bert.config.nlayers + layer if layer < 0 else layer + + token_embedding = bert.embeddings(input_ids) + if bert.config.learnable_pe: + pos_embedding = bert.positional_embedding( + bert.pos.to(token_embedding.device) + ) + x = bert.dropout(token_embedding + pos_embedding) + else: + x = bert.positional_embedding(token_embedding) + + padding_mask = ~attention_mask.bool() + + for i in range(layer_idx + 1): + if i == layer_idx: + x_in = x + x = bert.encoder.layers[i]( + x, src_key_padding_mask=padding_mask, is_causal=False + ) + + enc_layer = bert.encoder.layers[layer_idx] + query = enc_layer.norm1(x_in) if enc_layer.norm_first else x_in + _, attn_weights = enc_layer.self_attn( + query, + query, + query, + key_padding_mask=padding_mask, + need_weights=True, + average_attn_weights=False, + ) + return attn_weights diff --git a/pyproject.toml b/pyproject.toml index 9cb9e76d..73e728b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "helical" -version = "1.10.0" +version = "1.10.1" authors = [ { name="Helical Team", email="support@helical-ai.com" }, ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 3524599b..72399076 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,5 +2,5 @@ pytest==8.2.0 pytest-cov==5.0.0 pytest-mock==3.14.0 nbmake==1.5.4 -black==26.1.0 +black==26.3.1