Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ci/tests/test_nicheformer/test_nicheformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _tokenize(adata, **kwargs):
"ENSG00000000003": 32,
}
mocker.patch(
"helical.models.nicheformer.model.AutoTokenizer.from_pretrained",
"helical.models.nicheformer.model.NicheformerTokenizer.from_pretrained",
return_value=mock_tokenizer,
)

Expand All @@ -81,7 +81,7 @@ def _get_embeddings(input_ids, attention_mask, layer, with_context):
mock_model.get_embeddings.side_effect = _get_embeddings
mock_model.to.return_value = mock_model
mocker.patch(
"helical.models.nicheformer.model.AutoModelForMaskedLM.from_pretrained",
"helical.models.nicheformer.model.NicheformerForMaskedLM.from_pretrained",
return_value=mock_model,
)

Expand Down Expand Up @@ -199,7 +199,7 @@ def test_with_context_forwarded_to_model(self, nicheformer, mock_adata, _mocks):
assert mock_model.get_embeddings.call_args.kwargs["with_context"] is True

def test_attention_shape_invariant_to_masking(self, nicheformer, stub_bert):
nicheformer.model.bert = stub_bert
nicheformer.model.nicheformer = stub_bert
n_obs = 2
input_ids = torch.zeros((n_obs, _STUB_SEQ_LEN), dtype=torch.long)

Expand Down
192 changes: 192 additions & 0 deletions ci/tests/test_nicheformer/test_nicheformer_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""
Tests for PAD / MASK token contract.

The invariant, confirmed by the embedding layer's padding_idx and masking.py:
PAD_TOKEN = 1 – excluded from attention, embedding always zero
MASK_TOKEN = 0 – included in attention, model must predict the original
CLS_TOKEN = 2

These tests are pure unit tests: no disk access, no model downloads.
"""

import numpy as np
import torch
import pytest

from helical.models.nicheformer.tokenization_nicheformer import (
PAD_TOKEN,
MASK_TOKEN,
CLS_TOKEN,
_sub_tokenize_data,
)
import helical.models.nicheformer.masking as masking_module
from helical.models.nicheformer.masking import complete_masking
from helical.models.nicheformer.modeling_nicheformer import NicheformerModel
from helical.models.nicheformer.configuration_nicheformer import (
NicheformerConfig as ModelConfig,
)


# ---------------------------------------------------------------------------
# 1. Token ID consistency across all three modules
# ---------------------------------------------------------------------------


class TestTokenIdConsistency:
    """All three modules must agree on the numeric value of each special token.

    Also checks the embedding-layer consequences of the contract: the PAD row
    must be zero (via padding_idx) while the MASK row is a learned, non-zero
    vector.
    """

    @staticmethod
    def _fresh_model():
        """Build a small untrained model shared by the embedding tests.

        Extracted because three tests previously duplicated the same
        config + model construction.
        """
        cfg = ModelConfig(n_tokens=100, context_length=50, learnable_pe=True)
        return NicheformerModel(cfg)

    def test_pad_token_value(self):
        assert PAD_TOKEN == 1

    def test_mask_token_value(self):
        assert MASK_TOKEN == 0

    def test_cls_token_value(self):
        assert CLS_TOKEN == 2

    def test_pad_matches_masking_module(self):
        assert PAD_TOKEN == masking_module.PAD_TOKEN

    def test_mask_matches_masking_module(self):
        assert MASK_TOKEN == masking_module.MASK_TOKEN

    def test_cls_matches_masking_module(self):
        assert CLS_TOKEN == masking_module.CLS_TOKEN

    def test_embedding_padding_idx_matches_pad_token(self):
        """The embedding layer's padding_idx must equal PAD_TOKEN so that
        padding positions receive a zero embedding during both training and
        inference."""
        model = self._fresh_model()
        assert model.embeddings.padding_idx == PAD_TOKEN

    def test_pad_embedding_is_zero(self):
        """Consequence of padding_idx: the PAD row in the embedding table must
        be the zero vector."""
        model = self._fresh_model()
        pad_emb = model.embeddings(torch.tensor([PAD_TOKEN]))
        assert torch.all(pad_emb == 0), "PAD embedding is not zero"

    def test_mask_embedding_is_nonzero(self):
        """MASK is a learned token; its embedding must not be the zero vector
        after initialisation (xavier_normal_ will not produce exactly zeros)."""
        model = self._fresh_model()
        mask_emb = model.embeddings(torch.tensor([MASK_TOKEN]))
        assert not torch.all(mask_emb == 0), "MASK embedding is unexpectedly zero"


# ---------------------------------------------------------------------------
# 2. Tokenizer: trailing padding must carry PAD_TOKEN
# ---------------------------------------------------------------------------


class TestTokenizerPadding:
    """_sub_tokenize_data must fill unused trailing positions with PAD_TOKEN."""

    def _tokenize(self, x, max_seq_len=10, aux_tokens=30):
        # Thin wrapper so every test uses the same sequence length and
        # auxiliary-token offset.
        tokens = _sub_tokenize_data(x, max_seq_len=max_seq_len, aux_tokens=aux_tokens)
        return tokens.astype(np.int32)

    def test_sparse_cell_trailing_positions_are_pad(self):
        # One cell, five genes, only two expressed: the remaining
        # 10 - 2 = 8 slots of the sequence must all be PAD.
        expr = np.zeros((1, 5), dtype=np.float32)
        expr[0, 1], expr[0, 3] = 3.0, 1.0
        tokens = self._tokenize(expr)
        assert np.all(tokens[0, 2:] == PAD_TOKEN)

    def test_all_zero_cell_is_all_pad(self):
        # A cell with no expression at all tokenizes to pure padding.
        tokens = self._tokenize(np.zeros((1, 5), dtype=np.float32))
        assert np.all(tokens[0] == PAD_TOKEN)

    def test_trailing_pad_is_not_mask_token(self):
        # Padding must never be emitted as the MASK token (value 0).
        tokens = self._tokenize(np.zeros((1, 5), dtype=np.float32))
        assert not np.any(tokens[0] == MASK_TOKEN)


# ---------------------------------------------------------------------------
# 3. Attention mask: PAD excluded, MASK included
# ---------------------------------------------------------------------------


class TestAttentionMask:
    """The attention mask derived from token IDs must treat PAD and MASK
    oppositely: PAD → False (excluded), MASK → True (included)."""

    @staticmethod
    def _make_mask(token_ids: np.ndarray) -> np.ndarray:
        # Same rule the tokenizer applies: every position except PAD is
        # attended to.
        return token_ids != PAD_TOKEN

    def test_pad_positions_masked_out(self):
        attended = self._make_mask(np.array([[30, 31, PAD_TOKEN, PAD_TOKEN]]))
        # Both trailing PAD slots must be excluded.
        assert not attended[0, 2]
        assert not attended[0, 3]

    def test_gene_positions_attended(self):
        attended = self._make_mask(np.array([[30, 31, PAD_TOKEN, PAD_TOKEN]]))
        # Both real gene tokens must be included.
        assert attended[0, 0]
        assert attended[0, 1]

    def test_mask_token_is_attended(self):
        """A MASK token (value 0) placed in the sequence must be attended to,
        not silently treated as padding."""
        attended = self._make_mask(np.array([[MASK_TOKEN, 30, PAD_TOKEN]]))
        assert attended[0, 0], "MASK token must not be excluded from attention"
        assert not attended[0, 2], "PAD token must be excluded from attention"


# ---------------------------------------------------------------------------
# 4. complete_masking: PAD never masked, MASK token used for substitution
# ---------------------------------------------------------------------------


class TestMaskingBehavior:
    """complete_masking must respect the PAD/MASK contract."""

    @staticmethod
    def _batch(seq):
        # Build a one-row batch with the attention mask derived the same way
        # the tokenizer derives it (everything except PAD is attended).
        token_ids = torch.tensor([seq], dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": token_ids != PAD_TOKEN}

    def test_pad_positions_never_replaced(self):
        seq = [30, PAD_TOKEN, 31, PAD_TOKEN, PAD_TOKEN]
        out = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=100)
        pad_positions = [i for i, tok in enumerate(seq) if tok == PAD_TOKEN]
        # Even at masking_p=1.0, padding must survive untouched.
        for pos in pad_positions:
            assert out["masked_indices"][0, pos].item() == PAD_TOKEN

    def test_masked_positions_are_not_pad(self):
        """Replacing a real token must never produce PAD_TOKEN (1)."""
        seq = list(range(30, 50))  # 20 gene tokens, no padding
        torch.manual_seed(0)
        out = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=200)
        selected = out["mask"][0]
        replaced = out["masked_indices"][0]
        for pos in range(len(seq)):
            if not selected[pos]:
                continue
            assert replaced[pos].item() != PAD_TOKEN, (
                f"Position {pos} was masked to PAD_TOKEN"
            )

    def test_random_replacement_is_not_a_noop(self):
        """The 10 % random-replacement branch must actually write to the tensor.
        With a long sequence and p=1.0, ~10 % of tokens become random gene
        tokens (≠ original and ≠ MASK_TOKEN). A no-op bug produces zero such
        positions."""
        seq = list(range(30, 130))  # 100 gene tokens
        torch.manual_seed(42)
        out = complete_masking(self._batch(seq), masking_p=1.0, n_tokens=200)
        before = out["input_ids"][0]
        after = out["masked_indices"][0]
        became_random = (
            (after != before) & (after != MASK_TOKEN) & (after != PAD_TOKEN)
        )
        assert (
            became_random.any()
        ), "No random token replacements found — the assignment is likely a no-op"
42 changes: 17 additions & 25 deletions helical/models/nicheformer/LICENSE
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
BSD 3-Clause License
MIT License

Copyright (c) 2024, Theislab
All rights reserved.
Copyright (c) 2024 Nicheformer Contributors

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
5 changes: 5 additions & 0 deletions helical/models/nicheformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .model import Nicheformer
from .nicheformer_config import NicheformerConfig
from .modeling_nicheformer import (
NicheformerPreTrainedModel,
NicheformerModel,
NicheformerForMaskedLM,
)
61 changes: 61 additions & 0 deletions helical/models/nicheformer/configuration_nicheformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from transformers import PretrainedConfig


class NicheformerConfig(PretrainedConfig):
    # HF model-type identifier used for config auto-resolution.
    model_type = "nicheformer"

    def __init__(
        self,
        dim_model=512,
        nheads=16,
        dim_feedforward=1024,
        nlayers=12,
        dropout=0.0,
        batch_first=True,
        masking_p=0.15,
        n_tokens=20340,
        context_length=1500,
        cls_classes=164,
        supervised_task=None,
        learnable_pe=True,
        specie=True,
        assay=True,
        modality=True,
        **kwargs,
    ):
        """Create a Nicheformer configuration.

        Args:
            dim_model: Dimensionality of the model.
            nheads: Number of attention heads.
            dim_feedforward: Dimensionality of the MLPs inside the attention
                blocks.
            nlayers: Number of transformer layers.
            dropout: Dropout probability.
            batch_first: Whether the batch dimension comes first.
            masking_p: Probability of masking tokens.
            n_tokens: Total number of tokens (excluding auxiliary ones).
            context_length: Length of the context window.
            cls_classes: Number of classification classes.
            supervised_task: Type of supervised task, or None.
            learnable_pe: Whether to use learnable positional embeddings.
            specie: Whether to add a specie token.
            assay: Whether to add an assay token.
            modality: Whether to add a modality token.
            **kwargs: Forwarded to PretrainedConfig.
        """
        super().__init__(**kwargs)

        # Transformer architecture.
        self.dim_model = dim_model
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.nlayers = nlayers
        self.dropout = dropout
        self.batch_first = batch_first
        self.learnable_pe = learnable_pe
        self.context_length = context_length
        # Vocabulary and training objective.
        self.n_tokens = n_tokens
        self.masking_p = masking_p
        self.cls_classes = cls_classes
        self.supervised_task = supervised_task
        # Auxiliary metadata token flags.
        self.specie = specie
        self.assay = assay
        self.modality = modality
Loading
Loading