diff --git a/ci/tests/test_nicheformer/test_nicheformer_model.py b/ci/tests/test_nicheformer/test_nicheformer_model.py index b3d49b51..dd8f199f 100644 --- a/ci/tests/test_nicheformer/test_nicheformer_model.py +++ b/ci/tests/test_nicheformer/test_nicheformer_model.py @@ -69,7 +69,7 @@ def _tokenize(adata, **kwargs): "ENSG00000000003": 32, } mocker.patch( - "helical.models.nicheformer.model.AutoTokenizer.from_pretrained", + "helical.models.nicheformer.model.NicheformerTokenizer.from_pretrained", return_value=mock_tokenizer, ) @@ -81,7 +81,7 @@ def _get_embeddings(input_ids, attention_mask, layer, with_context): mock_model.get_embeddings.side_effect = _get_embeddings mock_model.to.return_value = mock_model mocker.patch( - "helical.models.nicheformer.model.AutoModelForMaskedLM.from_pretrained", + "helical.models.nicheformer.model.NicheformerForMaskedLM.from_pretrained", return_value=mock_model, ) @@ -199,7 +199,7 @@ def test_with_context_forwarded_to_model(self, nicheformer, mock_adata, _mocks): assert mock_model.get_embeddings.call_args.kwargs["with_context"] is True def test_attention_shape_invariant_to_masking(self, nicheformer, stub_bert): - nicheformer.model.bert = stub_bert + nicheformer.model.nicheformer = stub_bert n_obs = 2 input_ids = torch.zeros((n_obs, _STUB_SEQ_LEN), dtype=torch.long) diff --git a/ci/tests/test_nicheformer/test_nicheformer_tokens.py b/ci/tests/test_nicheformer/test_nicheformer_tokens.py new file mode 100644 index 00000000..7288b69b --- /dev/null +++ b/ci/tests/test_nicheformer/test_nicheformer_tokens.py @@ -0,0 +1,192 @@ +""" +Tests for PAD / MASK token contract. + +The invariant, confirmed by the embedding layer's padding_idx and masking.py: + PAD_TOKEN = 1 – excluded from attention, embedding always zero + MASK_TOKEN = 0 – included in attention, model must predict the original + CLS_TOKEN = 2 + +These tests are pure unit tests: no disk access, no model downloads. 
+""" + +import numpy as np +import torch +import pytest + +from helical.models.nicheformer.tokenization_nicheformer import ( + PAD_TOKEN, + MASK_TOKEN, + CLS_TOKEN, + _sub_tokenize_data, +) +import helical.models.nicheformer.masking as masking_module +from helical.models.nicheformer.masking import complete_masking +from helical.models.nicheformer.modeling_nicheformer import NicheformerModel +from helical.models.nicheformer.configuration_nicheformer import ( + NicheformerConfig as ModelConfig, +) + + +# --------------------------------------------------------------------------- +# 1. Token ID consistency across all three modules +# --------------------------------------------------------------------------- + + +class TestTokenIdConsistency: + """All three modules must agree on the numeric value of each special token.""" + + def test_pad_token_value(self): + assert PAD_TOKEN == 1 + + def test_mask_token_value(self): + assert MASK_TOKEN == 0 + + def test_cls_token_value(self): + assert CLS_TOKEN == 2 + + def test_pad_matches_masking_module(self): + assert PAD_TOKEN == masking_module.PAD_TOKEN + + def test_mask_matches_masking_module(self): + assert MASK_TOKEN == masking_module.MASK_TOKEN + + def test_cls_matches_masking_module(self): + assert CLS_TOKEN == masking_module.CLS_TOKEN + + def test_embedding_padding_idx_matches_pad_token(self): + """The embedding layer's padding_idx must equal PAD_TOKEN so that + padding positions receive a zero embedding during both training and + inference.""" + cfg = ModelConfig(n_tokens=100, context_length=50, learnable_pe=True) + model = NicheformerModel(cfg) + assert model.embeddings.padding_idx == PAD_TOKEN + + def test_pad_embedding_is_zero(self): + """Consequence of padding_idx: the PAD row in the embedding table must + be the zero vector.""" + cfg = ModelConfig(n_tokens=100, context_length=50, learnable_pe=True) + model = NicheformerModel(cfg) + pad_emb = model.embeddings(torch.tensor([PAD_TOKEN])) + assert torch.all(pad_emb == 
0), "PAD embedding is not zero" + + def test_mask_embedding_is_nonzero(self): + """MASK is a learned token; its embedding must not be the zero vector + after initialisation (xavier_normal_ will not produce exactly zeros).""" + cfg = ModelConfig(n_tokens=100, context_length=50, learnable_pe=True) + model = NicheformerModel(cfg) + mask_emb = model.embeddings(torch.tensor([MASK_TOKEN])) + assert not torch.all(mask_emb == 0), "MASK embedding is unexpectedly zero" + + +# --------------------------------------------------------------------------- +# 2. Tokenizer: trailing padding must carry PAD_TOKEN +# --------------------------------------------------------------------------- + + +class TestTokenizerPadding: + """_sub_tokenize_data must fill unused trailing positions with PAD_TOKEN.""" + + def _tokenize(self, x, max_seq_len=10, aux_tokens=30): + return _sub_tokenize_data( + x, max_seq_len=max_seq_len, aux_tokens=aux_tokens + ).astype(np.int32) + + def test_sparse_cell_trailing_positions_are_pad(self): + # 1 cell, 5 genes, only 2 non-zero → 8 trailing slots must be PAD + x = np.zeros((1, 5), dtype=np.float32) + x[0, 1] = 3.0 + x[0, 3] = 1.0 + tokens = self._tokenize(x) + assert (tokens[0, 2:] == PAD_TOKEN).all() + + def test_all_zero_cell_is_all_pad(self): + x = np.zeros((1, 5), dtype=np.float32) + tokens = self._tokenize(x) + assert (tokens[0] == PAD_TOKEN).all() + + def test_trailing_pad_is_not_mask_token(self): + x = np.zeros((1, 5), dtype=np.float32) + tokens = self._tokenize(x) + assert not (tokens[0] == MASK_TOKEN).any() + + +# --------------------------------------------------------------------------- +# 3. 
Attention mask: PAD excluded, MASK included +# --------------------------------------------------------------------------- + + +class TestAttentionMask: + """The attention mask derived from token IDs must treat PAD and MASK + oppositely: PAD → False (excluded), MASK → True (included).""" + + def _make_mask(self, token_ids: np.ndarray) -> np.ndarray: + # mirrors the tokenizer's attention_mask construction + return token_ids != PAD_TOKEN + + def test_pad_positions_masked_out(self): + ids = np.array([[30, 31, PAD_TOKEN, PAD_TOKEN]]) + mask = self._make_mask(ids) + assert not mask[0, 2] + assert not mask[0, 3] + + def test_gene_positions_attended(self): + ids = np.array([[30, 31, PAD_TOKEN, PAD_TOKEN]]) + mask = self._make_mask(ids) + assert mask[0, 0] + assert mask[0, 1] + + def test_mask_token_is_attended(self): + """A MASK token (value 0) placed in the sequence must be attended to, + not silently treated as padding.""" + ids = np.array([[MASK_TOKEN, 30, PAD_TOKEN]]) + mask = self._make_mask(ids) + assert mask[0, 0], "MASK token must not be excluded from attention" + assert not mask[0, 2], "PAD token must be excluded from attention" + + +# --------------------------------------------------------------------------- +# 4. 
complete_masking: PAD never masked, MASK token used for substitution +# --------------------------------------------------------------------------- + + +class TestMaskingBehavior: + """complete_masking must respect the PAD/MASK contract.""" + + def _batch(self, seq): + ids = torch.tensor([seq], dtype=torch.long) + return {"input_ids": ids, "attention_mask": ids != PAD_TOKEN} + + def test_pad_positions_never_replaced(self): + seq = [30, PAD_TOKEN, 31, PAD_TOKEN, PAD_TOKEN] + result = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=100) + for i, tok in enumerate(seq): + if tok == PAD_TOKEN: + assert result["masked_indices"][0, i].item() == PAD_TOKEN + + def test_masked_positions_are_not_pad(self): + """Replacing a real token must never produce PAD_TOKEN (1).""" + seq = list(range(30, 50)) # 20 gene tokens, no padding + torch.manual_seed(0) + result = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=200) + masked_pos = result["mask"][0] + for i in range(len(seq)): + if masked_pos[i]: + val = result["masked_indices"][0, i].item() + assert val != PAD_TOKEN, f"Position {i} was masked to PAD_TOKEN" + + def test_random_replacement_is_not_a_noop(self): + """The 10 % random-replacement branch must actually write to the tensor. + With a long sequence and p=1.0, ~10 % of tokens become random gene + tokens (≠ original and ≠ MASK_TOKEN). 
A no-op bug produces zero such + positions.""" + seq = list(range(30, 130)) # 100 gene tokens + torch.manual_seed(42) + result = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=200) + original = result["input_ids"][0] + modified = result["masked_indices"][0] + randomly_replaced = ( + (modified != original) & (modified != MASK_TOKEN) & (modified != PAD_TOKEN) + ) + assert ( + randomly_replaced.any() + ), "No random token replacements found — the assignment is likely a no-op" diff --git a/helical/models/nicheformer/LICENSE b/helical/models/nicheformer/LICENSE index 72c1b6c1..3c32891e 100644 --- a/helical/models/nicheformer/LICENSE +++ b/helical/models/nicheformer/LICENSE @@ -1,29 +1,21 @@ -BSD 3-Clause License +MIT License -Copyright (c) 2024, Theislab -All rights reserved. +Copyright (c) 2024 Nicheformer Contributors -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. 
Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/helical/models/nicheformer/__init__.py b/helical/models/nicheformer/__init__.py index 4bb2745b..0ea2c8d8 100644 --- a/helical/models/nicheformer/__init__.py +++ b/helical/models/nicheformer/__init__.py @@ -1,2 +1,7 @@ from .model import Nicheformer from .nicheformer_config import NicheformerConfig +from .modeling_nicheformer import ( + NicheformerPreTrainedModel, + NicheformerModel, + NicheformerForMaskedLM, +) diff --git a/helical/models/nicheformer/configuration_nicheformer.py b/helical/models/nicheformer/configuration_nicheformer.py new file mode 100644 index 00000000..bafdcdab --- /dev/null +++ b/helical/models/nicheformer/configuration_nicheformer.py @@ -0,0 +1,61 @@ +from transformers import PretrainedConfig + + +class NicheformerConfig(PretrainedConfig): + model_type = "nicheformer" + + def __init__( + self, + dim_model=512, + nheads=16, + dim_feedforward=1024, + nlayers=12, + dropout=0.0, + batch_first=True, + masking_p=0.15, + n_tokens=20340, + context_length=1500, + cls_classes=164, + supervised_task=None, + learnable_pe=True, + specie=True, + assay=True, + modality=True, + **kwargs, + ): + """Initialize NicheformerConfig. 
+ + Args: + dim_model: Dimensionality of the model + nheads: Number of attention heads + dim_feedforward: Dimensionality of MLPs in attention blocks + nlayers: Number of transformer layers + dropout: Dropout probability + batch_first: Whether batch dimension is first + masking_p: Probability of masking tokens + n_tokens: Total number of tokens (excluding auxiliary) + context_length: Length of the context window + cls_classes: Number of classification classes + supervised_task: Type of supervised task + learnable_pe: Whether to use learnable positional embeddings + specie: Whether to add specie token + assay: Whether to add assay token + modality: Whether to add modality token + """ + super().__init__(**kwargs) + + self.dim_model = dim_model + self.nheads = nheads + self.dim_feedforward = dim_feedforward + self.nlayers = nlayers + self.dropout = dropout + self.batch_first = batch_first + self.masking_p = masking_p + self.n_tokens = n_tokens + self.context_length = context_length + self.cls_classes = cls_classes + self.supervised_task = supervised_task + self.learnable_pe = learnable_pe + self.specie = specie + self.assay = assay + self.modality = modality diff --git a/helical/models/nicheformer/masking.py b/helical/models/nicheformer/masking.py new file mode 100644 index 00000000..d575b49c --- /dev/null +++ b/helical/models/nicheformer/masking.py @@ -0,0 +1,69 @@ +import torch +import random + +MASK_TOKEN = 0 +PAD_TOKEN = 1 +CLS_TOKEN = 2 + + +def complete_masking(batch, masking_p, n_tokens): + """Apply masking to input batch for masked language modeling. 
+ + Args: + batch (dict): Input batch containing 'input_ids' and 'attention_mask' + masking_p (float): Probability of masking a token + n_tokens (int): Total number of tokens in vocabulary + + Returns: + dict: Batch with masked indices and masking information + """ + device = batch["input_ids"].device + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + + # Create mask tensor (1 for tokens to be masked, 0 otherwise) + prob = torch.rand(input_ids.shape, device=device) + mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN) + + # For masked tokens: + # - 80% replace with MASK token + # - 10% replace with random token + # - 10% keep unchanged + masked_indices = input_ids.clone() + + # Calculate number of tokens to be masked + num_tokens_to_mask = mask.sum().item() + + # Determine which tokens get which type of masking + mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8 + random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask + + # Apply MASK token (80% of masked tokens) + masked_indices[mask] = torch.where( + mask_mask, + torch.tensor(MASK_TOKEN, device=device, dtype=torch.long), + masked_indices[mask], + ) + + # Apply random tokens (10% of masked tokens). + # Build a full-shape boolean mask first; single-level boolean assignment + # modifies the tensor in-place correctly (double-index does not). 
+ random_tokens = torch.randint( + 3, + n_tokens, # Start from 3 to avoid special tokens + (random_mask.sum(),), + device=device, + dtype=torch.long, + ) + random_full_mask = torch.zeros_like(input_ids, dtype=torch.bool) + random_full_mask[mask] = random_mask + masked_indices[random_full_mask] = random_tokens + + # 10% remain unchanged + + return { + "masked_indices": masked_indices, + "attention_mask": attention_mask, + "mask": mask, + "input_ids": input_ids, + } diff --git a/helical/models/nicheformer/model.py b/helical/models/nicheformer/model.py index e848a6fc..8ef679f7 100644 --- a/helical/models/nicheformer/model.py +++ b/helical/models/nicheformer/model.py @@ -4,7 +4,10 @@ from helical.utils.mapping import map_gene_symbols_to_ensembl_ids from anndata import AnnData from datasets import Dataset -from transformers import AutoModelForMaskedLM, AutoTokenizer +from helical.models.nicheformer.modeling_nicheformer import NicheformerForMaskedLM +from helical.models.nicheformer.tokenization_nicheformer import ( + NicheformerTokenizer, +) import numpy as np import torch import logging @@ -75,18 +78,13 @@ def __init__(self, configurer: NicheformerConfig = default_configurer) -> None: model_files_dir = str(self.files_config["model_files_dir"]) - self.tokenizer = AutoTokenizer.from_pretrained( - model_files_dir, trust_remote_code=True - ) - self.tokenizer.name_or_path = self.config["model_name"] + self.tokenizer = NicheformerTokenizer.from_pretrained(model_files_dir) technology_mean = self.config["technology_mean"] if technology_mean is not None: self.tokenizer._load_technology_mean(technology_mean) - self.model = AutoModelForMaskedLM.from_pretrained( - model_files_dir, trust_remote_code=True - ) + self.model = NicheformerForMaskedLM.from_pretrained(model_files_dir) self.model.eval() self.model.to(self.device) @@ -260,28 +258,28 @@ def _extract_attention_weights( torch.Tensor Attention weights of shape ``(batch, n_heads, seq_len, seq_len)``. 
""" - bert = self.model.bert - layer_idx = bert.config.nlayers + layer if layer < 0 else layer + nicheformer = self.model.nicheformer + layer_idx = nicheformer.config.nlayers + layer if layer < 0 else layer - token_embedding = bert.embeddings(input_ids) - if bert.config.learnable_pe: - pos_embedding = bert.positional_embedding( - bert.pos.to(token_embedding.device) + token_embedding = nicheformer.embeddings(input_ids) + if nicheformer.config.learnable_pe: + pos_embedding = nicheformer.positional_embedding( + nicheformer.pos.to(token_embedding.device) ) - x = bert.dropout(token_embedding + pos_embedding) + x = nicheformer.dropout(token_embedding + pos_embedding) else: - x = bert.positional_embedding(token_embedding) + x = nicheformer.positional_embedding(token_embedding) padding_mask = ~attention_mask.bool() for i in range(layer_idx + 1): if i == layer_idx: x_in = x - x = bert.encoder.layers[i]( + x = nicheformer.encoder.layers[i]( x, src_key_padding_mask=padding_mask, is_causal=False ) - enc_layer = bert.encoder.layers[layer_idx] + enc_layer = nicheformer.encoder.layers[layer_idx] query = enc_layer.norm1(x_in) if enc_layer.norm_first else x_in _, attn_weights = enc_layer.self_attn( query, diff --git a/helical/models/nicheformer/modeling_nicheformer.py b/helical/models/nicheformer/modeling_nicheformer.py new file mode 100644 index 00000000..2850fcf7 --- /dev/null +++ b/helical/models/nicheformer/modeling_nicheformer.py @@ -0,0 +1,253 @@ +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from transformers.modeling_outputs import MaskedLMOutput +from .configuration_nicheformer import NicheformerConfig +from .masking import complete_masking +import math + + +class PositionalEncoding(nn.Module): + """Positional encoding using sine and cosine functions.""" + + def __init__(self, d_model: int, max_seq_len: int): + super().__init__() + encoding = torch.zeros(max_seq_len, d_model) + position = torch.arange(0, max_seq_len, 
dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) + + encoding[:, 0::2] = torch.sin(position * div_term) + encoding[:, 1::2] = torch.cos(position * div_term) + encoding = encoding.unsqueeze(0) + + self.register_buffer("encoding", encoding, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Add positional encoding to input tensor.""" + return x + self.encoding[:, : x.size(1)] + + +class NicheformerPreTrainedModel(PreTrainedModel): + """Base class for Nicheformer models.""" + + config_class = NicheformerConfig + base_model_prefix = "nicheformer" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +class NicheformerModel(NicheformerPreTrainedModel): + def __init__(self, config: NicheformerConfig): + super().__init__(config) + + # Core transformer components + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=config.dim_model, + nhead=config.nheads, + dim_feedforward=config.dim_feedforward, + batch_first=config.batch_first, + dropout=config.dropout, + layer_norm_eps=1e-12, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=self.encoder_layer, + num_layers=config.nlayers, + enable_nested_tensor=False, + ) + + # Embedding layers + self.embeddings = nn.Embedding( + num_embeddings=config.n_tokens + 5, + embedding_dim=config.dim_model, + padding_idx=1, + ) + + if config.learnable_pe: + self.positional_embedding = nn.Embedding( + num_embeddings=config.context_length, embedding_dim=config.dim_model + ) + self.dropout = nn.Dropout(p=config.dropout) + self.register_buffer( + "pos", torch.arange(0, config.context_length, dtype=torch.long) + ) + else: + self.positional_embedding = PositionalEncoding( + d_model=config.dim_model, max_seq_len=config.context_length + ) + + # Initialize weights + 
self.post_init() + + def forward(self, input_ids, attention_mask=None): + token_embedding = self.embeddings(input_ids) + + if self.config.learnable_pe: + pos_embedding = self.positional_embedding( + self.pos.to(token_embedding.device) + ) + embeddings = self.dropout(token_embedding + pos_embedding) + else: + embeddings = self.positional_embedding(token_embedding) + + # Convert attention_mask to boolean and invert it for transformer's src_key_padding_mask + # True indicates positions that will be masked + if attention_mask is not None: + attention_mask = ~attention_mask.bool() + + transformer_output = self.encoder( + embeddings, + src_key_padding_mask=attention_mask if attention_mask is not None else None, + is_causal=False, + ) + + return transformer_output + + def get_embeddings( + self, + input_ids, + attention_mask=None, + layer: int = -1, + with_context: bool = False, + ) -> torch.Tensor: + """Get embeddings from the model. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + layer: Which transformer layer to extract embeddings from (-1 means last layer) + with_context: Whether to include context tokens in the embeddings + + Returns: + torch.Tensor: Embeddings tensor + """ + # Get token embeddings and positional encodings + token_embedding = self.embeddings(input_ids) + + if self.config.learnable_pe: + pos_embedding = self.positional_embedding( + self.pos.to(token_embedding.device) + ) + embeddings = self.dropout(token_embedding + pos_embedding) + else: + embeddings = self.positional_embedding(token_embedding) + + # Process through transformer layers up to desired layer + if layer < 0: + layer = self.config.nlayers + layer # -1 means last layer + + # Convert attention_mask to boolean and invert it for transformer's src_key_padding_mask + if attention_mask is not None: + padding_mask = ~attention_mask.bool() + else: + padding_mask = None + + # Process through each layer up to the desired one + for i in range(layer + 1): + embeddings = 
self.encoder.layers[i]( + embeddings, src_key_padding_mask=padding_mask, is_causal=False + ) + + # Remove context tokens (first 3 tokens) if not needed + if not with_context: + embeddings = embeddings[:, 3:, :] + + # Mean pooling over sequence dimension + embeddings = embeddings.mean(dim=1) + + return embeddings + + +class NicheformerForMaskedLM(NicheformerPreTrainedModel): + def __init__(self, config: NicheformerConfig): + super().__init__(config) + + self.nicheformer = NicheformerModel(config) + self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False) + self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens)) + + # Initialize weights + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + labels=None, + return_dict=None, + apply_masking=False, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Apply masking if requested (typically during training) + if apply_masking: + batch = {"input_ids": input_ids, "attention_mask": attention_mask} + masked_batch = complete_masking( + batch, self.config.masking_p, self.config.n_tokens + ) + input_ids = masked_batch["masked_indices"] + labels = masked_batch["input_ids"] # Original tokens become labels + mask = masked_batch["mask"] + # Only compute loss on masked tokens and ensure labels are long + labels = torch.where( + mask, labels, torch.tensor(-100, device=labels.device) + ).long() + + transformer_output = self.nicheformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + prediction_scores = self.classifier_head(transformer_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.n_tokens), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + (transformer_output,) + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + 
return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=transformer_output, + ) + + def get_embeddings( + self, + input_ids, + attention_mask=None, + layer: int = -1, + with_context: bool = False, + ) -> torch.Tensor: + """Get embeddings from the model. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + layer: Which transformer layer to extract embeddings from (-1 means last layer) + with_context: Whether to include context tokens in the embeddings + + Returns: + torch.Tensor: Embeddings tensor + """ + return self.nicheformer.get_embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + layer=layer, + with_context=with_context, + ) diff --git a/helical/models/nicheformer/nicheformer_config.py b/helical/models/nicheformer/nicheformer_config.py index 1d5073dd..364ecd31 100644 --- a/helical/models/nicheformer/nicheformer_config.py +++ b/helical/models/nicheformer/nicheformer_config.py @@ -14,11 +14,6 @@ "vocab.json", "model.safetensors", "model.h5ad", - "modeling_nicheformer.py", - "tokenization_nicheformer.py", - "configuration_nicheformer.py", - "masking.py", - "__init__.py", ] diff --git a/helical/models/nicheformer/tokenization_nicheformer.py b/helical/models/nicheformer/tokenization_nicheformer.py new file mode 100644 index 00000000..67efbff7 --- /dev/null +++ b/helical/models/nicheformer/tokenization_nicheformer.py @@ -0,0 +1,457 @@ +from typing import List, Dict, Optional, Union, Tuple +import numpy as np +from transformers import PreTrainedTokenizer +from dataclasses import dataclass +import torch +import anndata as ad +from scipy.sparse import issparse +import numba +import os +import json +import pandas as pd + +# Token IDs must match exactly with the original implementation. +# PAD=1 is confirmed by padding_idx=1 in the embedding layer and masking.py. 
+PAD_TOKEN = 1 +MASK_TOKEN = 0 +CLS_TOKEN = 2 + +# These mappings preserve the exact token IDs from the original implementation +MODALITY_DICT = { + "dissociated": 3, + "spatial": 4, +} + +SPECIES_DICT = { + "human": 5, + "Homo sapiens": 5, + "Mus musculus": 6, + "mouse": 6, +} + +TECHNOLOGY_DICT = { + "merfish": 7, + "MERFISH": 7, + "cosmx": 8, + "NanoString digital spatial profiling": 8, + "Xenium": 9, + "10x 5' v2": 10, + "10x 3' v3": 11, + "10x 3' v2": 12, + "10x 5' v1": 13, + "10x 3' v1": 14, + "10x 3' transcription profiling": 15, + "10x transcription profiling": 15, + "10x 5' transcription profiling": 16, + "CITE-seq": 17, + "Smart-seq v4": 18, +} + + +def sf_normalize(X: np.ndarray) -> np.ndarray: + """Size factor normalize to 10k counts.""" + X = X.copy() + counts = np.array(X.sum(axis=1)) + # avoid zero division error + counts += counts == 0.0 + # normalize to 10000 counts + scaling_factor = 10000.0 / counts + + if issparse(X): + from scipy.sparse import sparsefuncs + + sparsefuncs.inplace_row_scale(X, scaling_factor) + else: + np.multiply(X, scaling_factor.reshape((-1, 1)), out=X) + + return X + + +@numba.jit(nopython=True, nogil=True) +def _sub_tokenize_data( + x: np.ndarray, max_seq_len: int = -1, aux_tokens: int = 30 +) -> np.ndarray: + """Tokenize the input gene vector.""" + scores_final = np.empty( + (x.shape[0], max_seq_len if max_seq_len > 0 else x.shape[1]) + ) + for i, cell in enumerate(x): + nonzero_mask = np.nonzero(cell)[0] + sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])][:max_seq_len] + sorted_indices = sorted_indices + aux_tokens + if max_seq_len: + scores = np.ones(max_seq_len, dtype=np.int32) # 1 = PAD_TOKEN + else: + scores = np.ones(cell.shape[0], dtype=np.int32) # 1 = PAD_TOKEN + scores[: len(sorted_indices)] = sorted_indices.astype(np.int32) + scores_final[i, :] = scores + return scores_final + + +class NicheformerTokenizer(PreTrainedTokenizer): + """Tokenizer for Nicheformer that handles single-cell data.""" + + 
model_input_names = ["input_ids", "attention_mask"] + vocab_files_names = {"vocab_file": "vocab.json"} + + modality_dict = MODALITY_DICT + species_dict = SPECIES_DICT + technology_dict = TECHNOLOGY_DICT + + def _load_reference_model(self): + """Load reference model for gene alignment.""" + try: + model_path = os.path.join(self.name_or_path, "model.h5ad") + return ad.read_h5ad(model_path) + except Exception as e: + print(f"Warning: Could not load reference model: {e}") + return None + + def __init__( + self, + vocab_file=None, + max_length: int = 1500, + aux_tokens: int = 30, + median_counts_per_gene: Optional[np.ndarray] = None, + gene_names: Optional[List[str]] = None, + technology_mean: Optional[Union[str, np.ndarray]] = None, + **kwargs, + ): + # Initialize base vocabulary + self._vocabulary = { + "[PAD]": PAD_TOKEN, + "[MASK]": MASK_TOKEN, + "[CLS]": CLS_TOKEN, + } + + if vocab_file is not None: + with open(vocab_file, "r") as f: + self._vocabulary.update(json.load(f)) + else: + # Add modality tokens + for name, idx in self.modality_dict.items(): + self._vocabulary[f"[MODALITY_{name}]"] = idx + # Add species tokens + for name, idx in self.species_dict.items(): + if name in ["Homo sapiens", "Mus musculus"]: + continue # Skip redundant names + self._vocabulary[f"[SPECIES_{name}]"] = idx + # Add technology tokens + for name, idx in self.technology_dict.items(): + if name in ["MERFISH", "10x transcription profiling"]: + continue # Skip redundant names + clean_name = name.lower().replace(" ", "_").replace("'", "_") + self._vocabulary[f"[TECH_{clean_name}]"] = idx + + # Add gene tokens if provided + if gene_names is not None: + for i, gene in enumerate(gene_names): + self._vocabulary[gene] = i + aux_tokens + # Save vocabulary + os.makedirs("to_hf", exist_ok=True) + with open("to_hf/vocab.json", "w") as f: + json.dump(self._vocabulary, f, indent=4) + + super().__init__(**kwargs) + + self.max_length = max_length + self.aux_tokens = aux_tokens + 
self.median_counts_per_gene = median_counts_per_gene + self.gene_names = gene_names + self.name_or_path = kwargs.get("name_or_path", "aletlvl/Nicheformer") + + # Set up special token mappings + self._pad_token = "[PAD]" + self._mask_token = "[MASK]" + self._cls_token = "[CLS]" + + # Load technology mean if provided + self.technology_mean = None + if technology_mean is not None: + self._load_technology_mean(technology_mean) + + # Cache the reference model so it is not reloaded on every __call__ + self._reference_model_cache = self._load_reference_model() + + def _load_technology_mean(self, technology_mean): + """Load technology mean from file or array.""" + if isinstance(technology_mean, str): + try: + self.technology_mean = np.load(technology_mean) + print( + f"Loaded technology mean from {technology_mean} with shape {self.technology_mean.shape}" + ) + except Exception as e: + print( + f"Warning: Could not load technology mean from {technology_mean}: {e}" + ) + elif isinstance(technology_mean, np.ndarray): + self.technology_mean = technology_mean + print( + f"Using provided technology mean array with shape {self.technology_mean.shape}" + ) + else: + print(f"Warning: Invalid technology_mean type: {type(technology_mean)}") + + def get_vocab(self) -> Dict[str, int]: + """Returns the vocabulary mapping.""" + return self._vocabulary.copy() + + def _tokenize(self, text: str) -> List[str]: + """Tokenize text input.""" + # This tokenizer doesn't handle text input directly + raise NotImplementedError("This tokenizer only works with gene expression data") + + def _convert_token_to_id(self, token: str) -> int: + """Convert token to ID.""" + # First check special token mappings + if token in self.modality_dict: + return self.modality_dict[token] + if token in self.species_dict: + return self.species_dict[token] + if token in self.technology_dict: + return self.technology_dict[token] + # Then check vocabulary + return self._vocabulary.get(token, self._vocabulary["[PAD]"]) + + 
    def _convert_id_to_token(self, index: int) -> str:
        """Convert ID to token.

        Linear reverse lookup: metadata dicts first, then the vocabulary;
        unknown IDs map to "[PAD]". If an ID exists in more than one dict,
        the first match in this fixed order wins.
        """
        # First check special token mappings
        for token, idx in self.modality_dict.items():
            if idx == index:
                return token
        for token, idx in self.species_dict.items():
            if idx == index:
                return token
        for token, idx in self.technology_dict.items():
            if idx == index:
                return token
        # Then check vocabulary
        for token, idx in self._vocabulary.items():
            if idx == index:
                return token
        return "[PAD]"

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary to a file.

        Writes ``<prefix->vocab.json`` into ``save_directory`` and returns
        a 1-tuple with the written path (HF tokenizer contract).
        """
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocabulary, f, ensure_ascii=False)

        return (vocab_file,)

    def _tokenize_gene_expression(self, x: np.ndarray) -> np.ndarray:
        """Tokenize gene expression matrix.

        Args:
            x: Gene expression matrix (cells x genes)

        Returns:
            Tokenized matrix (int32), produced by ``_sub_tokenize_data``
            after normalization.
        """
        # Handle sparse input
        if issparse(x):
            x = x.toarray()

        # Normalize and scale.
        # NOTE(review): sf_normalize is a project helper — presumably
        # per-cell size-factor normalization; confirm against its definition.
        x = np.nan_to_num(x)
        x = sf_normalize(x)
        if self.median_counts_per_gene is not None:
            median_counts = self.median_counts_per_gene.copy()
            # Replace zero medians with 1 to avoid division by zero.
            median_counts += median_counts == 0
            x = x / median_counts.reshape((1, -1))

        # Apply technology mean normalization if available; silently skipped
        # when the mean's length does not match the gene dimension.
        if (
            self.technology_mean is not None
            and self.technology_mean.shape[0] == x.shape[1]
        ):
            # Avoid division by zero
            safe_mean = np.maximum(self.technology_mean, 1e-6)
            x = x / safe_mean

        # Apply log1p transformation
        x = np.log1p(x)

        # Convert to tokens (truncation/ordering handled by the helper).
        tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)

        return tokens.astype(np.int32)

    def __call__(
        self, data: Union[ad.AnnData, np.ndarray], **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Tokenize gene expression data.

        Args:
            data: AnnData object or numpy array of gene expression data

        Returns:
            Dictionary with input_ids and attention_mask tensors

        For AnnData input, gene columns are aligned to the cached reference
        model via an outer-join concat, and up to three metadata tokens
        ([species, technology, modality], in that order) are prepended from
        ``obs`` columns "specie", "assay" and "modality". Plain arrays skip
        both steps.
        """
        if isinstance(data, ad.AnnData):
            adata = data.copy()

            # Align with reference model if available
            reference_model = self._reference_model_cache
            if reference_model is not None:
                # Store original column types before concatenation, since
                # ad.concat with join="outer" can alter obs column dtypes.
                original_types = {}
                for col in ["modality", "specie", "assay"]:
                    if col in adata.obs.columns:
                        original_types[col] = adata.obs[col].dtype

                # Concatenate and then remove all reference cells; the concat
                # imposes the reference model's gene column layout.
                n_ref = len(reference_model)
                adata = ad.concat([reference_model, adata], join="outer", axis=0)
                # NOTE(review): this slice yields an AnnData view; the obs
                # assignments below may trigger an implicit copy — confirm
                # intended.
                adata = adata[n_ref:]

                # Restore original column types after concatenation
                for col, dtype in original_types.items():
                    if col in adata.obs.columns:
                        try:
                            adata.obs[col] = adata.obs[col].astype(dtype)
                        except Exception as e:
                            print(
                                f"Warning: Could not convert {col} back to {dtype}: {e}"
                            )

            # Get gene expression data
            X = adata.X

            # Get metadata for special tokens
            modality = (
                adata.obs["modality"] if "modality" in adata.obs.columns else None
            )
            species = adata.obs["specie"] if "specie" in adata.obs.columns else None
            technology = adata.obs["assay"] if "assay" in adata.obs.columns else None

            # Use integer values directly if available; otherwise map strings
            # through the metadata dicts, defaulting unknowns to [PAD].
            if modality is not None:
                try:
                    if pd.api.types.is_numeric_dtype(modality):
                        modality_tokens = modality.astype(int).tolist()
                    else:
                        modality_tokens = [
                            self.modality_dict.get(m, self._vocabulary["[PAD]"])
                            for m in modality
                        ]
                except Exception as e:
                    print(f"Warning: Error processing modality tokens: {e}")
                    modality_tokens = [self._vocabulary["[PAD]"]] * len(adata)
            else:
                modality_tokens = None

            if species is not None:
                try:
                    if pd.api.types.is_numeric_dtype(species):
                        species_tokens = species.astype(int).tolist()
                    else:
                        species_tokens = [
                            self.species_dict.get(s, self._vocabulary["[PAD]"])
                            for s in species
                        ]
                except Exception as e:
                    print(f"Warning: Error processing species tokens: {e}")
                    species_tokens = [self._vocabulary["[PAD]"]] * len(adata)
            else:
                species_tokens = None

            if technology is not None:
                try:
                    if pd.api.types.is_numeric_dtype(technology):
                        technology_tokens = technology.astype(int).tolist()
                    else:
                        technology_tokens = [
                            self.technology_dict.get(t, self._vocabulary["[PAD]"])
                            for t in technology
                        ]
                except Exception as e:
                    print(f"Warning: Error processing technology tokens: {e}")
                    technology_tokens = [self._vocabulary["[PAD]"]] * len(adata)
            else:
                technology_tokens = None
        else:
            # Raw array input: no metadata tokens, no reference alignment.
            X = data
            modality_tokens = None
            species_tokens = None
            technology_tokens = None

        # Tokenize gene expression data
        token_ids = self._tokenize_gene_expression(X)

        # Add special tokens if available - changed order to [species, technology, modality]
        special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
        special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)

        if species_tokens is not None:
            special_tokens[:, 0] = species_tokens
            special_token_mask[:, 0] = True

        if technology_tokens is not None:
            special_tokens[:, 1] = technology_tokens
            special_token_mask[:, 1] = True

        if modality_tokens is not None:
            special_tokens[:, 2] = modality_tokens
            special_token_mask[:, 2] = True

        # Only keep the special tokens that are present (have True in mask).
        # Row 0's mask is representative: each column is set for all rows or none.
        special_tokens = special_tokens[:, special_token_mask[0]]

        if special_tokens.size > 0:
            # Prepend metadata tokens and truncate the gene tokens so the
            # total length never exceeds max_length.
            token_ids = np.concatenate(
                [
                    special_tokens,
                    token_ids[:, : (self.max_length - special_tokens.shape[1])],
                ],
                axis=1,
            )

        # Create attention mask: only [PAD] positions are masked out
        # ([MASK] positions remain attended, matching the model contract).
        attention_mask = token_ids != self._vocabulary["[PAD]"]

        return {
            "input_ids": torch.tensor(token_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask),
        }

    def get_vocab_size(self) -> int:
        """Get vocabulary size.

        With gene names: reserved low IDs plus one ID per gene. Otherwise
        the highest metadata-token ID + 1.
        NOTE(review): a vocabulary loaded from ``vocab_file`` is not
        consulted here — confirm that is intended.
        """
        if self.gene_names is not None:
            return len(self.gene_names) + self.aux_tokens
        return (
            max(
                max(self.modality_dict.values()),
                max(self.species_dict.values()),
                max(self.technology_dict.values()),
            )
            + 1
        )

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a string. Not used for gene expression."""
        raise NotImplementedError("This tokenizer only works with gene expression data")

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Build model inputs from a sequence by adding special tokens."""
        # For gene expression data, special tokens are handled in __call__
        return token_ids_0

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Get list where entries are [1] if a token is [special] else [0]."""
        # Consider tokens < aux_tokens as special (the reserved low ID range)
        return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]
diff --git a/pyproject.toml b/pyproject.toml
index db8e1b52..11c56a0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 [project]
 name = "helical"
-version = "1.11.0"
+version = "1.12.0"
 authors = [
 { name="Helical Team", email="support@helical-ai.com" },
 ]