Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions helical/models/nicheformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
from .model import Nicheformer
from .nicheformer_config import NicheformerConfig
from .configuration_nicheformer import NicheformerConfig
from .modeling_nicheformer import (
NicheformerPreTrainedModel,
NicheformerModel,
NicheformerForMaskedLM,
)
61 changes: 61 additions & 0 deletions helical/models/nicheformer/configuration_nicheformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from transformers import PretrainedConfig


class NicheformerConfig(PretrainedConfig):
    """Configuration for Nicheformer models.

    Holds the hyper-parameters needed to instantiate a Nicheformer
    transformer and plugs into the Hugging Face ``PretrainedConfig``
    save/load machinery.
    """

    model_type = "nicheformer"

    def __init__(
        self,
        dim_model=512,
        nheads=16,
        dim_feedforward=1024,
        nlayers=12,
        dropout=0.0,
        batch_first=True,
        masking_p=0.15,
        n_tokens=20340,
        context_length=1500,
        cls_classes=164,
        supervised_task=None,
        learnable_pe=True,
        specie=True,
        assay=True,
        modality=True,
        **kwargs,
    ):
        """Build a NicheformerConfig.

        Args:
            dim_model: Dimensionality of the model.
            nheads: Number of attention heads.
            dim_feedforward: Dimensionality of the MLPs in the attention blocks.
            nlayers: Number of transformer layers.
            dropout: Dropout probability.
            batch_first: Whether the batch dimension comes first.
            masking_p: Probability of masking a token.
            n_tokens: Total number of tokens (excluding auxiliary tokens).
            context_length: Length of the context window.
            cls_classes: Number of classification classes.
            supervised_task: Type of supervised task, if any.
            learnable_pe: Whether to use learnable positional embeddings.
            specie: Whether to add a specie token.
            assay: Whether to add an assay token.
            modality: Whether to add a modality token.
            **kwargs: Forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        # Record every hyper-parameter on the instance so that the
        # PretrainedConfig serialization machinery picks them all up.
        hyperparameters = {
            "dim_model": dim_model,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "nlayers": nlayers,
            "dropout": dropout,
            "batch_first": batch_first,
            "masking_p": masking_p,
            "n_tokens": n_tokens,
            "context_length": context_length,
            "cls_classes": cls_classes,
            "supervised_task": supervised_task,
            "learnable_pe": learnable_pe,
            "specie": specie,
            "assay": assay,
            "modality": modality,
        }
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)
65 changes: 65 additions & 0 deletions helical/models/nicheformer/masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
import random

# Special-token ids in the Nicheformer vocabulary.
MASK_TOKEN = 0  # substituted for tokens selected for masking
PAD_TOKEN = 1   # padding; never selected for masking
CLS_TOKEN = 2   # classification token; never selected for masking


def complete_masking(batch, masking_p, n_tokens):
    """Apply BERT-style masking to an input batch for masked language modeling.

    Each non-PAD, non-CLS token is selected for masking with probability
    ``masking_p``. Of the selected tokens: 80% are replaced with
    ``MASK_TOKEN``, 10% are replaced with a random non-special token drawn
    from ``[3, n_tokens)``, and 10% are left unchanged.

    Args:
        batch (dict): Input batch containing 'input_ids' and 'attention_mask'.
        masking_p (float): Probability of selecting a token for masking.
        n_tokens (int): Total vocabulary size; random replacements are drawn
            from [3, n_tokens) to avoid the special tokens.

    Returns:
        dict: 'masked_indices' (corrupted ids), 'attention_mask',
            'mask' (boolean tensor of selected positions), and the
            original 'input_ids'.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    device = input_ids.device

    # Select positions to mask (1 for tokens to be masked, 0 otherwise);
    # PAD and CLS positions are excluded.
    prob = torch.rand(input_ids.shape, device=device)
    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)

    masked_indices = input_ids.clone()
    num_tokens_to_mask = mask.sum().item()

    # Partition the selected tokens: 80% -> MASK, 10% -> random, 10% -> keep.
    # random_mask picks half of the remaining 20%.
    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask

    # Random replacement tokens, starting from 3 to avoid special tokens.
    random_tokens = torch.randint(
        3,
        n_tokens,
        (int(random_mask.sum()),),
        device=device,
        dtype=torch.long,
    )

    # BUGFIX: mutate the gathered 1-D slice and scatter it back once.
    # The previous `masked_indices[mask][random_mask] = random_tokens`
    # assigned into a temporary copy created by boolean advanced indexing,
    # so the random replacements were silently discarded.
    selected = masked_indices[mask]
    selected[mask_mask] = MASK_TOKEN
    selected[random_mask] = random_tokens
    # The remaining ~10% keep their original value.
    masked_indices[mask] = selected

    return {
        "masked_indices": masked_indices,
        "attention_mask": attention_mask,
        "mask": mask,
        "input_ids": input_ids,
    }
13 changes: 8 additions & 5 deletions helical/models/nicheformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from helical.utils.mapping import map_gene_symbols_to_ensembl_ids
from anndata import AnnData
from datasets import Dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer
from helical.models.nicheformer.modeling_nicheformer import NicheformerForMaskedLM
from helical.models.nicheformer.tokenization_nicheformer import (
NicheformerTokenizer,
)
import numpy as np
import torch
import logging
Expand Down Expand Up @@ -75,17 +78,17 @@ def __init__(self, configurer: NicheformerConfig = default_configurer) -> None:

model_files_dir = str(self.files_config["model_files_dir"])

self.tokenizer = AutoTokenizer.from_pretrained(
model_files_dir, trust_remote_code=True
self.tokenizer = NicheformerTokenizer.from_pretrained(
model_files_dir
)
self.tokenizer.name_or_path = self.config["model_name"]

technology_mean = self.config["technology_mean"]
if technology_mean is not None:
self.tokenizer._load_technology_mean(technology_mean)

self.model = AutoModelForMaskedLM.from_pretrained(
model_files_dir, trust_remote_code=True
self.model = NicheformerForMaskedLM.from_pretrained(
model_files_dir
)
self.model.eval()
self.model.to(self.device)
Expand Down
Loading
Loading