From 9985d6e3355c5a15f60752eaa673bfcbf9166092 Mon Sep 17 00:00:00 2001
From: bputzeys
Date: Tue, 24 Mar 2026 13:11:47 +0100
Subject: [PATCH] replaced AutoTokenizer/AutoModelForMaskedLM with direct
 imports of NicheformerTokenizer/NicheformerForMaskedLM, no trust_remote_code

---
Review notes (ignored by git-am, not part of the commit message):
 - __init__.py: export the Hugging Face config as NicheformerModelConfig so it
   no longer shadows the helical NicheformerConfig imported one line above.
 - masking.py: write MASK/random tokens back through a temporary tensor;
   `masked_indices[mask][random_mask] = ...` only assigned into a copy.
 - modeling_nicheformer.py: import complete_masking from .masking instead of
   duplicating it without the MASK/PAD/CLS constants (NameError as soon as
   forward(apply_masking=True) is called).
 - tokenization_nicheformer.py: _sub_tokenize_data now works with its own
   default max_seq_len=-1 (no truncation) instead of slicing with [:-1].
 - Blob ids on the `index` lines of the three edited new files are carried over
   from the previous revision and no longer match the file contents.

 helical/models/nicheformer/__init__.py        |   6 +
 .../nicheformer/configuration_nicheformer.py  |  61 +++
 helical/models/nicheformer/masking.py         |  69 +++
 helical/models/nicheformer/model.py           |  13 +-
 .../nicheformer/modeling_nicheformer.py       | 253 ++++++++++
 .../models/nicheformer/nicheformer_config.py  |   5 -
 .../nicheformer/tokenization_nicheformer.py   | 459 ++++++++++++++++++
 pyproject.toml                                |   2 +-
 8 files changed, 857 insertions(+), 11 deletions(-)
 create mode 100644 helical/models/nicheformer/configuration_nicheformer.py
 create mode 100644 helical/models/nicheformer/masking.py
 create mode 100644 helical/models/nicheformer/modeling_nicheformer.py
 create mode 100644 helical/models/nicheformer/tokenization_nicheformer.py

diff --git a/helical/models/nicheformer/__init__.py b/helical/models/nicheformer/__init__.py
index 4bb2745b..3939490f 100644
--- a/helical/models/nicheformer/__init__.py
+++ b/helical/models/nicheformer/__init__.py
@@ -1,2 +1,8 @@
 from .model import Nicheformer
 from .nicheformer_config import NicheformerConfig
+from .configuration_nicheformer import NicheformerConfig as NicheformerModelConfig
+from .modeling_nicheformer import (
+    NicheformerPreTrainedModel,
+    NicheformerModel,
+    NicheformerForMaskedLM,
+)
diff --git a/helical/models/nicheformer/configuration_nicheformer.py b/helical/models/nicheformer/configuration_nicheformer.py
new file mode 100644
index 00000000..bafdcdab
--- /dev/null
+++ b/helical/models/nicheformer/configuration_nicheformer.py
@@ -0,0 +1,61 @@
+from transformers import PretrainedConfig
+
+
+class NicheformerConfig(PretrainedConfig):
+    model_type = "nicheformer"
+
+    def __init__(
+        self,
+        dim_model=512,
+        nheads=16,
+        dim_feedforward=1024,
+        nlayers=12,
+        dropout=0.0,
+        batch_first=True,
+        masking_p=0.15,
+        n_tokens=20340,
+        context_length=1500,
+        cls_classes=164,
+        supervised_task=None,
+        learnable_pe=True,
+        specie=True,
+        assay=True,
+        modality=True,
+        **kwargs,
+    ):
+        """Initialize NicheformerConfig.
+
+        Args:
+            dim_model: Dimensionality of the model
+            nheads: Number of attention heads
+            dim_feedforward: Dimensionality of MLPs in attention blocks
+            nlayers: Number of transformer layers
+            dropout: Dropout probability
+            batch_first: Whether batch dimension is first
+            masking_p: Probability of masking tokens
+            n_tokens: Total number of tokens (excluding auxiliary)
+            context_length: Length of the context window
+            cls_classes: Number of classification classes
+            supervised_task: Type of supervised task
+            learnable_pe: Whether to use learnable positional embeddings
+            specie: Whether to add specie token
+            assay: Whether to add assay token
+            modality: Whether to add modality token
+        """
+        super().__init__(**kwargs)
+
+        self.dim_model = dim_model
+        self.nheads = nheads
+        self.dim_feedforward = dim_feedforward
+        self.nlayers = nlayers
+        self.dropout = dropout
+        self.batch_first = batch_first
+        self.masking_p = masking_p
+        self.n_tokens = n_tokens
+        self.context_length = context_length
+        self.cls_classes = cls_classes
+        self.supervised_task = supervised_task
+        self.learnable_pe = learnable_pe
+        self.specie = specie
+        self.assay = assay
+        self.modality = modality
diff --git a/helical/models/nicheformer/masking.py b/helical/models/nicheformer/masking.py
new file mode 100644
index 00000000..59535c55
--- /dev/null
+++ b/helical/models/nicheformer/masking.py
@@ -0,0 +1,69 @@
+import torch
+import random
+
+# NOTE(review): tokenization_nicheformer.py declares PAD_TOKEN = 0 and
+# MASK_TOKEN = 1, whereas this module (and the padding_idx=1 embedding in
+# modeling_nicheformer.py) uses MASK_TOKEN = 0 and PAD_TOKEN = 1. Confirm which
+# numbering the pretrained checkpoint expects.
+MASK_TOKEN = 0
+PAD_TOKEN = 1
+CLS_TOKEN = 2
+
+
+def complete_masking(batch, masking_p, n_tokens):
+    """Apply masking to input batch for masked language modeling.
+
+    Args:
+        batch (dict): Input batch containing 'input_ids' and 'attention_mask'
+        masking_p (float): Probability of masking a token
+        n_tokens (int): Total number of tokens in vocabulary
+
+    Returns:
+        dict: Batch with masked indices and masking information
+    """
+    device = batch["input_ids"].device
+    input_ids = batch["input_ids"]
+    attention_mask = batch["attention_mask"]
+
+    # Create mask tensor (1 for tokens to be masked, 0 otherwise)
+    prob = torch.rand(input_ids.shape, device=device)
+    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)
+
+    # For masked tokens:
+    # - 80% replace with MASK token
+    # - 10% replace with random token
+    # - 10% keep unchanged
+    masked_indices = input_ids.clone()
+
+    # Calculate number of tokens to be masked
+    num_tokens_to_mask = mask.sum().item()
+
+    # Determine which tokens get which type of masking
+    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
+    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask
+
+    # Draw the random replacement tokens (10% of masked tokens)
+    random_tokens = torch.randint(
+        3,
+        n_tokens,  # Start from 3 to avoid special tokens
+        (int(random_mask.sum()),),
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Boolean-mask indexing returns a copy, so `masked_indices[mask][...] = x`
+    # would silently discard the write. Edit the selected tokens in a
+    # temporary tensor and store them back with a single indexed assignment.
+    selected = masked_indices[mask]
+    selected[mask_mask] = MASK_TOKEN
+    selected[random_mask] = random_tokens
+    masked_indices[mask] = selected
+
+    # The remaining 10% stay unchanged
+
+    return {
+        "masked_indices": masked_indices,
+        "attention_mask": attention_mask,
+        "mask": mask,
+        "input_ids": input_ids,
+    }
diff --git a/helical/models/nicheformer/model.py b/helical/models/nicheformer/model.py
index e848a6fc..a0f5856d 100644
--- a/helical/models/nicheformer/model.py
+++ b/helical/models/nicheformer/model.py
@@ -4,7 +4,10 @@
 from helical.utils.mapping import map_gene_symbols_to_ensembl_ids
 from anndata import AnnData
 from datasets import Dataset
-from transformers import AutoModelForMaskedLM, AutoTokenizer
+from helical.models.nicheformer.modeling_nicheformer import NicheformerForMaskedLM
+from helical.models.nicheformer.tokenization_nicheformer import (
+    NicheformerTokenizer,
+)
 import numpy as np
 import torch
 import logging
@@ -75,8 +78,8 @@ def __init__(self, configurer: NicheformerConfig = default_configurer) -> None:
 
         model_files_dir = str(self.files_config["model_files_dir"])
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_files_dir, trust_remote_code=True
+        self.tokenizer = NicheformerTokenizer.from_pretrained(
+            model_files_dir
         )
         self.tokenizer.name_or_path = self.config["model_name"]
 
@@ -84,8 +87,8 @@ def __init__(self, configurer: NicheformerConfig = default_configurer) -> None:
         if technology_mean is not None:
             self.tokenizer._load_technology_mean(technology_mean)
 
-        self.model = AutoModelForMaskedLM.from_pretrained(
-            model_files_dir, trust_remote_code=True
+        self.model = NicheformerForMaskedLM.from_pretrained(
+            model_files_dir
         )
         self.model.eval()
         self.model.to(self.device)
diff --git a/helical/models/nicheformer/modeling_nicheformer.py b/helical/models/nicheformer/modeling_nicheformer.py
new file mode 100644
index 00000000..16b3e4f4
--- /dev/null
+++ b/helical/models/nicheformer/modeling_nicheformer.py
@@ -0,0 +1,253 @@
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import MaskedLMOutput
+from .configuration_nicheformer import NicheformerConfig
+from .masking import complete_masking
+import math
+
+
+class PositionalEncoding(nn.Module):
+    """Positional encoding using sine and cosine functions."""
+
+    def __init__(self, d_model: int, max_seq_len: int):
+        super().__init__()
+        encoding = torch.zeros(max_seq_len, d_model)
+        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
+
+        encoding[:, 0::2] = torch.sin(position * div_term)
+        encoding[:, 1::2] = torch.cos(position * div_term)
+        encoding = encoding.unsqueeze(0)
+
+        self.register_buffer("encoding", encoding, persistent=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Add positional encoding to input tensor."""
+        return x + self.encoding[:, : x.size(1)]
+
+
+class NicheformerPreTrainedModel(PreTrainedModel):
+    """Base class for Nicheformer models."""
+
+    config_class = NicheformerConfig
+    base_model_prefix = "nicheformer"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_normal_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+
+class NicheformerModel(NicheformerPreTrainedModel):
+    def __init__(self, config: NicheformerConfig):
+        super().__init__(config)
+
+        # Core transformer components
+        self.encoder_layer = nn.TransformerEncoderLayer(
+            d_model=config.dim_model,
+            nhead=config.nheads,
+            dim_feedforward=config.dim_feedforward,
+            batch_first=config.batch_first,
+            dropout=config.dropout,
+            layer_norm_eps=1e-12,
+        )
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer=self.encoder_layer,
+            num_layers=config.nlayers,
+            enable_nested_tensor=False,
+        )
+
+        # Embedding layers
+        self.embeddings = nn.Embedding(
+            num_embeddings=config.n_tokens + 5,
+            embedding_dim=config.dim_model,
+            padding_idx=1,
+        )
+
+        if config.learnable_pe:
+            self.positional_embedding = nn.Embedding(
+                num_embeddings=config.context_length, embedding_dim=config.dim_model
+            )
+            self.dropout = nn.Dropout(p=config.dropout)
+            self.register_buffer(
+                "pos", torch.arange(0, config.context_length, dtype=torch.long)
+            )
+        else:
+            self.positional_embedding = PositionalEncoding(
+                d_model=config.dim_model, max_seq_len=config.context_length
+            )
+
+        # Initialize weights
+        self.post_init()
+
+    def forward(self, input_ids, attention_mask=None):
+        token_embedding = self.embeddings(input_ids)
+
+        if self.config.learnable_pe:
+            pos_embedding = self.positional_embedding(
+                self.pos.to(token_embedding.device)
+            )
+            embeddings = self.dropout(token_embedding + pos_embedding)
+        else:
+            embeddings = self.positional_embedding(token_embedding)
+
+        # Convert attention_mask to boolean and invert it for transformer's src_key_padding_mask
+        # True indicates positions that will be masked
+        if attention_mask is not None:
+            attention_mask = ~attention_mask.bool()
+
+        transformer_output = self.encoder(
+            embeddings,
+            src_key_padding_mask=attention_mask if attention_mask is not None else None,
+            is_causal=False,
+        )
+
+        return transformer_output
+
+    def get_embeddings(
+        self,
+        input_ids,
+        attention_mask=None,
+        layer: int = -1,
+        with_context: bool = False,
+    ) -> torch.Tensor:
+        """Get embeddings from the model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Attention mask
+            layer: Which transformer layer to extract embeddings from (-1 means last layer)
+            with_context: Whether to include context tokens in the embeddings
+
+        Returns:
+            torch.Tensor: Embeddings tensor
+        """
+        # Get token embeddings and positional encodings
+        token_embedding = self.embeddings(input_ids)
+
+        if self.config.learnable_pe:
+            pos_embedding = self.positional_embedding(
+                self.pos.to(token_embedding.device)
+            )
+            embeddings = self.dropout(token_embedding + pos_embedding)
+        else:
+            embeddings = self.positional_embedding(token_embedding)
+
+        # Process through transformer layers up to desired layer
+        if layer < 0:
+            layer = self.config.nlayers + layer  # -1 means last layer
+
+        # Convert attention_mask to boolean and invert it for transformer's src_key_padding_mask
+        if attention_mask is not None:
+            padding_mask = ~attention_mask.bool()
+        else:
+            padding_mask = None
+
+        # Process through each layer up to the desired one
+        for i in range(layer + 1):
+            embeddings = self.encoder.layers[i](
+                embeddings, src_key_padding_mask=padding_mask, is_causal=False
+            )
+
+        # Remove context tokens (first 3 tokens) if not needed
+        if not with_context:
+            embeddings = embeddings[:, 3:, :]
+
+        # Mean pooling over sequence dimension
+        embeddings = embeddings.mean(dim=1)
+
+        return embeddings
+
+
+class NicheformerForMaskedLM(NicheformerPreTrainedModel):
+    def __init__(self, config: NicheformerConfig):
+        super().__init__(config)
+
+        self.nicheformer = NicheformerModel(config)
+        self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False)
+        self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens))
+
+        # Initialize weights
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        labels=None,
+        return_dict=None,
+        apply_masking=False,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # Apply masking if requested (typically during training)
+        if apply_masking:
+            batch = {"input_ids": input_ids, "attention_mask": attention_mask}
+            masked_batch = complete_masking(
+                batch, self.config.masking_p, self.config.n_tokens
+            )
+            input_ids = masked_batch["masked_indices"]
+            labels = masked_batch["input_ids"]  # Original tokens become labels
+            mask = masked_batch["mask"]
+            # Only compute loss on masked tokens and ensure labels are long
+            labels = torch.where(
+                mask, labels, torch.tensor(-100, device=labels.device)
+            ).long()
+
+        transformer_output = self.nicheformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        prediction_scores = self.classifier_head(transformer_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.n_tokens), labels.view(-1)
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + (transformer_output,)
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=transformer_output,
+        )
+
+    def get_embeddings(
+        self,
+        input_ids,
+        attention_mask=None,
+        layer: int = -1,
+        with_context: bool = False,
+    ) -> torch.Tensor:
+        """Get embeddings from the model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Attention mask
+            layer: Which transformer layer to extract embeddings from (-1 means last layer)
+            with_context: Whether to include context tokens in the embeddings
+
+        Returns:
+            torch.Tensor: Embeddings tensor
+        """
+        return self.nicheformer.get_embeddings(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            layer=layer,
+            with_context=with_context,
+        )
diff --git a/helical/models/nicheformer/nicheformer_config.py b/helical/models/nicheformer/nicheformer_config.py
index 1d5073dd..364ecd31 100644
--- a/helical/models/nicheformer/nicheformer_config.py
+++ b/helical/models/nicheformer/nicheformer_config.py
@@ -14,11 +14,6 @@
     "vocab.json",
     "model.safetensors",
     "model.h5ad",
-    "modeling_nicheformer.py",
-    "tokenization_nicheformer.py",
-    "configuration_nicheformer.py",
-    "masking.py",
-    "__init__.py",
 ]
 
 
diff --git a/helical/models/nicheformer/tokenization_nicheformer.py b/helical/models/nicheformer/tokenization_nicheformer.py
new file mode 100644
index 00000000..95ab3820
--- /dev/null
+++ b/helical/models/nicheformer/tokenization_nicheformer.py
@@ -0,0 +1,459 @@
+from typing import List, Dict, Optional, Union, Tuple
+import numpy as np
+from transformers import PreTrainedTokenizer
+from dataclasses import dataclass
+import torch
+import anndata as ad
+from scipy.sparse import issparse
+import numba
+import os
+import json
+from huggingface_hub import hf_hub_download
+import pandas as pd
+
+# Token IDs must match exactly with the original implementation
+PAD_TOKEN = 0
+MASK_TOKEN = 1
+CLS_TOKEN = 2
+
+# These mappings preserve the exact token IDs from the original implementation
+MODALITY_DICT = {
+    "dissociated": 3,
+    "spatial": 4,
+}
+
+SPECIES_DICT = {
+    "human": 5,
+    "Homo sapiens": 5,
+    "Mus musculus": 6,
+    "mouse": 6,
+}
+
+TECHNOLOGY_DICT = {
+    "merfish": 7,
+    "MERFISH": 7,
+    "cosmx": 8,
+    "NanoString digital spatial profiling": 8,
+    "Xenium": 9,
+    "10x 5' v2": 10,
+    "10x 3' v3": 11,
+    "10x 3' v2": 12,
+    "10x 5' v1": 13,
+    "10x 3' v1": 14,
+    "10x 3' transcription profiling": 15,
+    "10x transcription profiling": 15,
+    "10x 5' transcription profiling": 16,
+    "CITE-seq": 17,
+    "Smart-seq v4": 18,
+}
+
+
+def sf_normalize(X: np.ndarray) -> np.ndarray:
+    """Size factor normalize to 10k counts."""
+    X = X.copy()
+    counts = np.array(X.sum(axis=1))
+    # avoid zero division error
+    counts += counts == 0.0
+    # normalize to 10000 counts
+    scaling_factor = 10000.0 / counts
+
+    if issparse(X):
+        from scipy.sparse import sparsefuncs
+
+        sparsefuncs.inplace_row_scale(X, scaling_factor)
+    else:
+        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)
+
+    return X
+
+
+@numba.jit(nopython=True, nogil=True)
+def _sub_tokenize_data(
+    x: np.ndarray, max_seq_len: int = -1, aux_tokens: int = 30
+) -> np.ndarray:
+    """Tokenize the input gene vector."""
+    # A non-positive max_seq_len means "no truncation": keep one slot per gene.
+    width = max_seq_len if max_seq_len > 0 else x.shape[1]
+    scores_final = np.empty((x.shape[0], width))
+    for i, cell in enumerate(x):
+        nonzero_mask = np.nonzero(cell)[0]
+        sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])][:width]
+        sorted_indices = sorted_indices + aux_tokens
+        scores = np.zeros(width, dtype=np.int32)
+        scores[: len(sorted_indices)] = sorted_indices.astype(np.int32)
+        scores_final[i, :] = scores
+    return scores_final
+
+
+class NicheformerTokenizer(PreTrainedTokenizer):
+    """Tokenizer for Nicheformer that handles single-cell data."""
+
+    model_input_names = ["input_ids", "attention_mask"]
+    vocab_files_names = {"vocab_file": "vocab.json"}
+
+    modality_dict = MODALITY_DICT
+    species_dict = SPECIES_DICT
+    technology_dict = TECHNOLOGY_DICT
+
+    def _load_reference_model(self):
+        """Load reference model for gene alignment."""
+        try:
+            # Get the model name or path from the tokenizer
+            repo_id = (
+                self.name_or_path
+                if hasattr(self, "name_or_path")
+                else "aletlvl/Nicheformer"
+            )
+
+            # Download the reference model if not already cached
+            model_path = hf_hub_download(repo_id=repo_id, filename="model.h5ad")
+            return ad.read_h5ad(model_path)
+        except Exception as e:
+            print(f"Warning: Could not load reference model: {e}")
+            return None
+
+    def __init__(
+        self,
+        vocab_file=None,
+        max_length: int = 1500,
+        aux_tokens: int = 30,
+        median_counts_per_gene: Optional[np.ndarray] = None,
+        gene_names: Optional[List[str]] = None,
+        technology_mean: Optional[Union[str, np.ndarray]] = None,
+        **kwargs,
+    ):
+        # Initialize base vocabulary
+        self._vocabulary = {
+            "[PAD]": PAD_TOKEN,
+            "[MASK]": MASK_TOKEN,
+            "[CLS]": CLS_TOKEN,
+        }
+
+        if vocab_file is not None:
+            with open(vocab_file, "r") as f:
+                self._vocabulary.update(json.load(f))
+        else:
+            # Add modality tokens
+            for name, idx in self.modality_dict.items():
+                self._vocabulary[f"[MODALITY_{name}]"] = idx
+            # Add species tokens
+            for name, idx in self.species_dict.items():
+                if name in ["Homo sapiens", "Mus musculus"]:
+                    continue  # Skip redundant names
+                self._vocabulary[f"[SPECIES_{name}]"] = idx
+            # Add technology tokens
+            for name, idx in self.technology_dict.items():
+                if name in ["MERFISH", "10x transcription profiling"]:
+                    continue  # Skip redundant names
+                clean_name = name.lower().replace(" ", "_").replace("'", "_")
+                self._vocabulary[f"[TECH_{clean_name}]"] = idx
+
+            # Add gene tokens if provided
+            if gene_names is not None:
+                for i, gene in enumerate(gene_names):
+                    self._vocabulary[gene] = i + aux_tokens
+                # Save vocabulary
+                os.makedirs("to_hf", exist_ok=True)
+                with open("to_hf/vocab.json", "w") as f:
+                    json.dump(self._vocabulary, f, indent=4)
+
+        super().__init__(**kwargs)
+
+        self.max_length = max_length
+        self.aux_tokens = aux_tokens
+        self.median_counts_per_gene = median_counts_per_gene
+        self.gene_names = gene_names
+        self.name_or_path = kwargs.get("name_or_path", "aletlvl/Nicheformer")
+
+        # Set up special token mappings
+        self._pad_token = "[PAD]"
+        self._mask_token = "[MASK]"
+        self._cls_token = "[CLS]"
+
+        # Load technology mean if provided
+        self.technology_mean = None
+        if technology_mean is not None:
+            self._load_technology_mean(technology_mean)
+
+    def _load_technology_mean(self, technology_mean):
+        """Load technology mean from file or array."""
+        if isinstance(technology_mean, str):
+            try:
+                self.technology_mean = np.load(technology_mean)
+                print(
+                    f"Loaded technology mean from {technology_mean} with shape {self.technology_mean.shape}"
+                )
+            except Exception as e:
+                print(
+                    f"Warning: Could not load technology mean from {technology_mean}: {e}"
+                )
+        elif isinstance(technology_mean, np.ndarray):
+            self.technology_mean = technology_mean
+            print(
+                f"Using provided technology mean array with shape {self.technology_mean.shape}"
+            )
+        else:
+            print(f"Warning: Invalid technology_mean type: {type(technology_mean)}")
+
+    def get_vocab(self) -> Dict[str, int]:
+        """Returns the vocabulary mapping."""
+        return self._vocabulary.copy()
+
+    def _tokenize(self, text: str) -> List[str]:
+        """Tokenize text input."""
+        # This tokenizer doesn't handle text input directly
+        raise NotImplementedError("This tokenizer only works with gene expression data")
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Convert token to ID."""
+        # First check special token mappings
+        if token in self.modality_dict:
+            return self.modality_dict[token]
+        if token in self.species_dict:
+            return self.species_dict[token]
+        if token in self.technology_dict:
+            return self.technology_dict[token]
+        # Then check vocabulary
+        return self._vocabulary.get(token, self._vocabulary["[PAD]"])
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Convert ID to token."""
+        # First check special token mappings
+        for token, idx in self.modality_dict.items():
+            if idx == index:
+                return token
+        for token, idx in self.species_dict.items():
+            if idx == index:
+                return token
+        for token, idx in self.technology_dict.items():
+            if idx == index:
+                return token
+        # Then check vocabulary
+        for token, idx in self._vocabulary.items():
+            if idx == index:
+                return token
+        return "[PAD]"
+
+    def save_vocabulary(
+        self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        """Save the vocabulary to a file."""
+        vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            json.dump(self._vocabulary, f, ensure_ascii=False)
+
+        return (vocab_file,)
+
+    def _tokenize_gene_expression(self, x: np.ndarray) -> np.ndarray:
+        """Tokenize gene expression matrix.
+
+        Args:
+            x: Gene expression matrix (cells x genes)
+
+        Returns:
+            Tokenized matrix
+        """
+        # Handle sparse input
+        if issparse(x):
+            x = x.toarray()
+
+        # Normalize and scale
+        x = np.nan_to_num(x)
+        x = sf_normalize(x)
+        if self.median_counts_per_gene is not None:
+            median_counts = self.median_counts_per_gene.copy()
+            median_counts += median_counts == 0
+            x = x / median_counts.reshape((1, -1))
+
+        # Apply technology mean normalization if available
+        if (
+            self.technology_mean is not None
+            and self.technology_mean.shape[0] == x.shape[1]
+        ):
+            # Avoid division by zero
+            safe_mean = np.maximum(self.technology_mean, 1e-6)
+            x = x / safe_mean
+
+        # Apply log1p transformation
+        x = np.log1p(x)
+
+        # Convert to tokens
+        tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
+
+        return tokens.astype(np.int32)
+
+    def __call__(
+        self, data: Union[ad.AnnData, np.ndarray], **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        """Tokenize gene expression data.
+
+        Args:
+            data: AnnData object or numpy array of gene expression data
+
+        Returns:
+            Dictionary with input_ids and attention_mask tensors
+        """
+        if isinstance(data, ad.AnnData):
+            adata = data.copy()
+
+            # Align with reference model if available
+            if hasattr(self, "_load_reference_model"):
+                reference_model = self._load_reference_model()
+                if reference_model is not None:
+                    # Store original column types before concatenation
+                    original_types = {}
+                    for col in ["modality", "specie", "assay"]:
+                        if col in adata.obs.columns:
+                            original_types[col] = adata.obs[col].dtype
+
+                    # Concatenate and then remove the reference
+                    adata = ad.concat([reference_model, adata], join="outer", axis=0)
+                    adata = adata[1:]
+
+                    # Restore original column types after concatenation
+                    for col, dtype in original_types.items():
+                        if col in adata.obs.columns:
+                            try:
+                                adata.obs[col] = adata.obs[col].astype(dtype)
+                            except Exception as e:
+                                print(
+                                    f"Warning: Could not convert {col} back to {dtype}: {e}"
+                                )
+
+            # Get gene expression data
+            X = adata.X
+
+            # Get metadata for special tokens
+            modality = (
+                adata.obs["modality"] if "modality" in adata.obs.columns else None
+            )
+            species = adata.obs["specie"] if "specie" in adata.obs.columns else None
+            technology = adata.obs["assay"] if "assay" in adata.obs.columns else None
+
+            # Use integer values directly if available
+            if modality is not None:
+                try:
+                    if pd.api.types.is_numeric_dtype(modality):
+                        modality_tokens = modality.astype(int).tolist()
+                    else:
+                        modality_tokens = [
+                            self.modality_dict.get(m, self._vocabulary["[PAD]"])
+                            for m in modality
+                        ]
+                except Exception as e:
+                    print(f"Warning: Error processing modality tokens: {e}")
+                    modality_tokens = [self._vocabulary["[PAD]"]] * len(adata)
+            else:
+                modality_tokens = None
+
+            if species is not None:
+                try:
+                    if pd.api.types.is_numeric_dtype(species):
+                        species_tokens = species.astype(int).tolist()
+                    else:
+                        species_tokens = [
+                            self.species_dict.get(s, self._vocabulary["[PAD]"])
+                            for s in species
+                        ]
+                except Exception as e:
+                    print(f"Warning: Error processing species tokens: {e}")
+                    species_tokens = [self._vocabulary["[PAD]"]] * len(adata)
+            else:
+                species_tokens = None
+
+            if technology is not None:
+                try:
+                    if pd.api.types.is_numeric_dtype(technology):
+                        technology_tokens = technology.astype(int).tolist()
+                    else:
+                        technology_tokens = [
+                            self.technology_dict.get(t, self._vocabulary["[PAD]"])
+                            for t in technology
+                        ]
+                except Exception as e:
+                    print(f"Warning: Error processing technology tokens: {e}")
+                    technology_tokens = [self._vocabulary["[PAD]"]] * len(adata)
+            else:
+                technology_tokens = None
+        else:
+            X = data
+            modality_tokens = None
+            species_tokens = None
+            technology_tokens = None
+
+        # Tokenize gene expression data
+        token_ids = self._tokenize_gene_expression(X)
+
+        # Add special tokens if available - changed order to [species, technology, modality]
+        special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
+        special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
+
+        if species_tokens is not None:
+            special_tokens[:, 0] = species_tokens
+            special_token_mask[:, 0] = True
+
+        if technology_tokens is not None:
+            special_tokens[:, 1] = technology_tokens
+            special_token_mask[:, 1] = True
+
+        if modality_tokens is not None:
+            special_tokens[:, 2] = modality_tokens
+            special_token_mask[:, 2] = True
+
+        # Only keep the special tokens that are present (have True in mask)
+        special_tokens = special_tokens[:, special_token_mask[0]]
+
+        if special_tokens.size > 0:
+            token_ids = np.concatenate(
+                [
+                    special_tokens,
+                    token_ids[:, : (self.max_length - special_tokens.shape[1])],
+                ],
+                axis=1,
+            )
+
+        # Create attention mask
+        attention_mask = token_ids != self._vocabulary["[PAD]"]
+
+        return {
+            "input_ids": torch.tensor(token_ids, dtype=torch.long),
+            "attention_mask": torch.tensor(attention_mask),
+        }
+
+    def get_vocab_size(self) -> int:
+        """Get vocabulary size."""
+        if self.gene_names is not None:
+            return len(self.gene_names) + self.aux_tokens
+        return (
+            max(
+                max(self.modality_dict.values()),
+                max(self.species_dict.values()),
+                max(self.technology_dict.values()),
+            )
+            + 1
+        )
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        """Convert a sequence of tokens to a string. Not used for gene expression."""
+        raise NotImplementedError("This tokenizer only works with gene expression data")
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """Build model inputs from a sequence by adding special tokens."""
+        # For gene expression data, special tokens are handled in __call__
+        return token_ids_0
+
+    def get_special_tokens_mask(
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False,
+    ) -> List[int]:
+        """Get list where entries are [1] if a token is [special] else [0]."""
+        # Consider tokens < aux_tokens as special
+        return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]
diff --git a/pyproject.toml b/pyproject.toml
index db8e1b52..11c56a0e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "1.11.0"
+version = "1.12.0"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]